
Commit df40d5e

[LLM Runtime] Add Script for PPL Evaluation (#685)
1 parent: 12882d8

9 files changed (+577, -135 lines)


.gitignore

Lines changed: 8 additions & 0 deletions

```diff
@@ -1,3 +1,8 @@
+/intel_extension_for_transformers/llm/runtime/graph/*
+!/intel_extension_for_transformers/llm/runtime/graph/*.*
+!/intel_extension_for_transformers/llm/runtime/graph/*/
+### ignore binary files in llm-runtime ###
+
 *.pyc
 .vscode
 .idea
@@ -11,6 +16,7 @@
 *.log
 *.swp
 *.onnx
+*.bin
 tags
 build/
 _build
@@ -32,6 +38,8 @@ CMakeUserPresets.json
 
 /intel_extension_for_transformers/llm/runtime/.vs
 /intel_extension_for_transformers/llm/runtime/out
+/intel_extension_for_transformers/llm/runtime/graph/out
+/intel_extension_for_transformers/llm/runtime/graph/runtime_outs
 /examples/**/*.npy
 /examples/**/*.bin
 /examples/**/*.yaml
```

intel_extension_for_transformers/llm/runtime/graph/README.md

Lines changed: 4 additions & 1 deletion

````diff
@@ -351,7 +351,7 @@ class StopOnTokens(StoppingCriteria):
         self.min_length = min_length
         self.start_length = start_length
         self.stop_token_id = stop_token_id
-
+
     def __call__(
         self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
     ) -> bool:
@@ -373,3 +373,6 @@ stopping_criteria = StoppingCriteriaList(
 
 outputs = model.generate(inputs, streamer=streamer, stopping_criteria=stopping_criteria)
 ```
+
+### 6. Perplexity (measuring model quality)
+You can use the [scripts/perplexity.py](./scripts/perplexity.py) script to measure perplexity over a given dataset (or a subset of it). Run `python scripts/perplexity.py --help` for detailed usage. For more information on the perplexity metric, see https://huggingface.co/docs/transformers/perplexity.
````
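
For reference, perplexity is the exponential of the average negative log-likelihood the model assigns to each next token. Below is a minimal sketch of that metric using Hugging Face `transformers`, in the spirit of the guide linked in the README; the checkpoint name, input text, and single-pass evaluation are illustrative placeholders and are not taken from `scripts/perplexity.py`.

```python
# Illustrative sketch of the perplexity metric itself, not of scripts/perplexity.py.
# The model name and text below are placeholders.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "EleutherAI/gpt-neo-125m"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

text = "Perplexity is measured over a held-out piece of text."
input_ids = tokenizer(text, return_tensors="pt").input_ids

with torch.no_grad():
    # With labels equal to the inputs, the returned loss is the mean negative
    # log-likelihood of each next token, so exp(loss) is the perplexity.
    loss = model(input_ids, labels=input_ids).loss

print("perplexity:", torch.exp(loss).item())
```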

intel_extension_for_transformers/llm/runtime/graph/__init__.py

Lines changed: 70 additions & 42 deletions

```diff
@@ -15,11 +15,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-from transformers import AutoConfig, AutoTokenizer
-from intel_extension_for_transformers.llm.runtime.graph.scripts.convert import convert_model
+
 import torch
+from intel_extension_for_transformers.llm.runtime.graph.scripts.convert import convert_model
+from transformers import AutoConfig, AutoTokenizer
+
 model_maps = {"gpt_neox": "gptneox", "gpt_bigcode": "starcoder"}
 
+
 class Model:
     def __init__(self):
         self.module = None
@@ -28,55 +31,68 @@ def __init__(self):
         self.bin_file = None
         self.generate_round = 0
 
-    def __import_package(self, model_name):
+    def __import_package(self, model_type):
         if self.module:
             return
-        if model_name == "gptj":
+        if model_type == "gptj":
            import intel_extension_for_transformers.llm.runtime.graph.gptj_cpp as cpp_model
-        elif model_name == "falcon":
+        elif model_type == "falcon":
            import intel_extension_for_transformers.llm.runtime.graph.falcon_cpp as cpp_model
-        elif model_name == "gptneox":
+        elif model_type == "gptneox":
            import intel_extension_for_transformers.llm.runtime.graph.gptneox_cpp as cpp_model
-        elif model_name == "dolly":
+        elif model_type == "dolly":
            import intel_extension_for_transformers.llm.runtime.graph.dolly_cpp as cpp_model
-        elif model_name == "llama" or model_name == "llama2":
+        elif model_type == "llama" or model_type == "llama2":
            import intel_extension_for_transformers.llm.runtime.graph.llama_cpp as cpp_model
-        elif model_name == "mpt":
+        elif model_type == "mpt":
            import intel_extension_for_transformers.llm.runtime.graph.mpt_cpp as cpp_model
-        elif model_name == "gpt_bigcode" or model_name == "starcoder":
+        elif model_type == "gpt_bigcode" or model_type == "starcoder":
            import intel_extension_for_transformers.llm.runtime.graph.starcoder_cpp as cpp_model
-        elif model_name == "opt":
+        elif model_type == "opt":
            import intel_extension_for_transformers.llm.runtime.graph.opt_cpp as cpp_model
-        elif model_name == "bloom":
+        elif model_type == "bloom":
            import intel_extension_for_transformers.llm.runtime.graph.bloom_cpp as cpp_model
-        elif model_name == "chatglm":
+        elif model_type == "chatglm":
            import intel_extension_for_transformers.llm.runtime.graph.chatglm_cpp as cpp_model
-        elif model_name == "chatglm2":
+        elif model_type == "chatglm2":
            import intel_extension_for_transformers.llm.runtime.graph.chatglm2_cpp as cpp_model
-        elif model_name == "baichuan":
+        elif model_type == "baichuan":
            import intel_extension_for_transformers.llm.runtime.graph.baichuan_cpp as cpp_model
-        elif model_name == "polyglot":
+        elif model_type == "polyglot":
            import intel_extension_for_transformers.llm.runtime.graph.polyglot_cpp as cpp_model
-        elif model_name == "mistral":
+        elif model_type == "mistral":
            import intel_extension_for_transformers.llm.runtime.graph.mistral_cpp as cpp_model
         else:
-            raise TypeError("Unspported model type {}!".format(model_name))
+            raise TypeError("Unspported model type {}!".format(model_type))
         self.module = cpp_model
 
+    @staticmethod
+    def get_model_type(model_config):
+        model_type = model_maps.get(model_config.model_type, model_config.model_type)
+        if model_type == "chatglm" and "chatglm2" in model_config._name_or_path:
+            model_type = "chatglm2"
+        return model_type
+
     def init(self, model_name, not_quant=False, use_cache=False, **quant_kwargs):
         self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
         self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-        model_type = model_maps.get(self.config.model_type, self.config.model_type)
-        if model_type == "chatglm" and "chatglm2" in self.config._name_or_path:
-            model_type = "chatglm2"
+        model_type = Model.get_model_type(self.config)
         self.__import_package(model_type)
 
         # check cache and quantization
         output_path = "runtime_outs"
-        if not os.path.exists(output_path):
-            os.makedirs(output_path)
+        os.makedirs(output_path, exist_ok=True)
         fp32_bin = "{}/ne_{}_f32.bin".format(output_path, model_type)
-        quant_bin = "{}/ne_{}_q.bin".format(output_path, model_type)
+        quant_desc = quant_kwargs['weight_dtype']
+        if quant_kwargs['use_ggml']:
+            quant_desc += "_ggml"
+        else:
+            quant_desc += "_jblas_c" + quant_kwargs['compute_dtype']
+            if quant_kwargs['group_size'] == -1:
+                quant_desc += "_pc"
+            else:
+                quant_desc += "_g{}".format(quant_kwargs['group_size'])
+        quant_bin = "{}/ne_{}_q_{}.bin".format(output_path, model_type, quant_desc)
 
         if not_quant:
             self.bin_file = fp32_bin
@@ -85,20 +101,22 @@ def init(self, model_name, not_quant=False, use_cache=False, **quant_kwargs):
         if use_cache and os.path.exists(self.bin_file):
             return
 
-        convert_model(model_name, fp32_bin, "f32")
-        assert os.path.exists(fp32_bin), "Fail to convert pytorch model"
+        if not use_cache or not os.path.exists(fp32_bin):
+            convert_model(model_name, fp32_bin, "f32")
+            assert os.path.exists(fp32_bin), "Fail to convert pytorch model"
 
         if not_quant:
             print("FP32 model will be used.")
             return
-        self.module.Model.quant_model(model_path = fp32_bin, out_path = quant_bin, **quant_kwargs)
+        self.module.Model.quant_model(model_path=fp32_bin, out_path=quant_bin, **quant_kwargs)
         assert os.path.exists(quant_bin), "Fail to quantize model"
-
+
         # clean
-        os.remove(fp32_bin)
+        if not use_cache:
+            os.remove(fp32_bin)
 
-    def init_from_bin(self, model_name, model_path, **generate_kwargs):
-        self.__import_package(model_name)
+    def init_from_bin(self, model_type, model_path, **generate_kwargs):
+        self.__import_package(model_type)
         self.model = self.module.Model()
         if "threads" not in generate_kwargs:
             threads = os.getenv("OMP_NUM_THREADS")
@@ -108,11 +126,9 @@ def init_from_bin(self, model_name, model_path, **generate_kwargs):
             generate_kwargs["threads"] = int(threads)
         self.model.init_model(model_path, **generate_kwargs)
 
-    def quant_model(self, model_name, model_path, out_path, **quant_kwargs):
-        self.__import_package(model_name)
-        self.module.Model.quant_model(model_path = model_path,
-                                      out_path = out_path, **quant_kwargs)
-
+    def quant_model(self, model_type, model_path, out_path, **quant_kwargs):
+        self.__import_package(model_type)
+        self.module.Model.quant_model(model_path=model_path, out_path=out_path, **quant_kwargs)
 
     def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=False, stopping_criteria=None, **generate_kwargs):
         max_new_tokens = generate_kwargs.get("max_new_tokens", -1)
@@ -129,8 +145,7 @@ def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=Fa
         ret = input_ids.tolist()
 
         beam_search = False
-        if ("num_beams" in generate_kwargs and generate_kwargs["num_beams"] > 1) and not \
-                generate_kwargs.get("do_sample", False):
+        if (generate_kwargs.get("num_beams", 1) > 1) and not generate_kwargs.get("do_sample", False):
             beam_search = True
         if not beam_search:
             # TODO support multi batch
@@ -142,30 +157,43 @@ def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=Fa
                 Make sure that `num_beams` is set to 1."
             if self.generate_round == 0 and not ignore_prompt:
                 streamer.put(input_ids)
-
+
         if interactive:
             self.model.reset_token_end()
         out_count = 0
+        input_list = input_ids.tolist()
         while True:
-            response = self.model.generate(input_ids = input_ids.tolist())
+            response = self.model.generate(input_ids=input_list)
+            input_list = [] # next-token stage will use previous output
             if len(response) == 0:
                 break
             if streamer:
                 streamer.put(torch.tensor([response[0]]))
             for i in range(len(response)):
                 ret[i].extend(response[i])
+            if beam_search:
+                break
            if stopping_criteria is not None:
                 if stopping_criteria(torch.tensor(ret), None):
                     break
             elif ret[0][-1] == self.tokenizer.eos_token_id or \
-                (max_new_tokens != -1 and out_count > max_new_tokens):
+                    (max_new_tokens != -1 and out_count > max_new_tokens):
                 break
             out_count += 1
         if streamer:
             streamer.end()
-
+
         self.generate_round += 1
         return ret
 
     def is_token_end(self):
         return self.model.is_token_end()
+
+    def __call__(self, input_ids, reinit=False, **kwargs):
+        if self.model is None:
+            self.init_from_bin(self.model_type, self.bin_file, **kwargs)
+            self.generate_round = 0
+        elif reinit:
+            self.model.reinit()
+            self.generate_round = 0
+        return self.model.evaluate(input_ids.tolist())
```
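
One practical effect of the `__init__.py` change is that quantized model files are now written under names that encode the quantization settings, so different settings map to different files in `runtime_outs` instead of sharing one `ne_{model}_q.bin`. The sketch below simply replays the naming logic from the diff above with example settings; the `quant_kwargs` values are illustrative, not defaults taken from this commit.

```python
# Replays the quant-cache naming introduced above; the kwargs values here are examples only.
quant_kwargs = {"weight_dtype": "int4", "use_ggml": False, "compute_dtype": "int8", "group_size": 32}
model_type = "llama"
output_path = "runtime_outs"

quant_desc = quant_kwargs["weight_dtype"]
if quant_kwargs["use_ggml"]:
    quant_desc += "_ggml"
else:
    quant_desc += "_jblas_c" + quant_kwargs["compute_dtype"]
    if quant_kwargs["group_size"] == -1:
        quant_desc += "_pc"  # group_size == -1 gets the "_pc" suffix
    else:
        quant_desc += "_g{}".format(quant_kwargs["group_size"])

quant_bin = "{}/ne_{}_q_{}.bin".format(output_path, model_type, quant_desc)
print(quant_bin)  # -> runtime_outs/ne_llama_q_int4_jblas_cint8_g32.bin
```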
