forked from b0kch01/llama-cpu
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathexample.py
executable file
·107 lines (88 loc) · 2.99 KB
/
example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.
from typing import Tuple
import os
import sys
import torch
import fire
import time
import json
from pathlib import Path
from llama import ModelArgs, Transformer, Tokenizer, LLaMA
def load(
    ckpt_dir: str,
    tokenizer_path: str,
    max_seq_len: int = 17,
    max_batch_size: int = 1,
) -> LLaMA:
    """Build a CPU-resident ``LLaMA`` generator from a checkpoint directory.

    Reads ``params.json`` and every ``*.pth`` file in *ckpt_dir*, builds a
    single ``Transformer``, and copies each checkpoint file's tensors into
    the matching slice of the model's parameters.

    Args:
        ckpt_dir: Directory containing ``params.json`` and the ``*.pth``
            checkpoint files.
        tokenizer_path: Passed to ``Tokenizer(model_path=...)``.
        max_seq_len: Forwarded to ``ModelArgs``.
        max_batch_size: Forwarded to ``ModelArgs``.

    Returns:
        A ``LLaMA`` instance wrapping the loaded model and tokenizer.

    Raises:
        FileNotFoundError: If *ckpt_dir* contains no ``*.pth`` files (the
            model would otherwise be returned with weights that were never
            loaded), or if ``params.json`` is missing.
        KeyError: If a model parameter's layer name is not in the merge
            table below, or a checkpoint file lacks a parameter.
    """
    print("Creating model...")
    start_time = time.time()

    ckpt_root = Path(ckpt_dir)
    # Sorted so that checkpoint index i is stable across runs; the index
    # decides which slice of each parameter a file fills.
    checkpoints = sorted(ckpt_root.glob("*.pth"))
    if not checkpoints:
        raise FileNotFoundError(
            f"No '*.pth' checkpoint files found in {ckpt_root}"
        )

    with open(ckpt_root / "params.json", "r", encoding="utf-8") as f:
        params = json.load(f)

    model_args: ModelArgs = ModelArgs(
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
        **params,
    )

    tokenizer = Tokenizer(model_path=tokenizer_path)
    model_args.vocab_size = tokenizer.n_words

    model = Transformer(model_args)

    # Original copyright by tloen
    # https://github.com/tloen/llama-int8/blob/main/example.py
    #
    # Maps a parameter's layer name (second-to-last component of its dotted
    # name) to the dimension along which the checkpoint files are
    # concatenated: 0 = first dim, -1 = last dim, None = not split (taken
    # whole from the first checkpoint file).
    key_to_dim = {
        "w1": 0,
        "w2": -1,
        "w3": 0,
        "wo": -1,
        "wq": 0,
        "wk": 0,
        "wv": 0,
        "output": 0,
        "tok_embeddings": -1,
        "ffn_norm": None,
        "attention_norm": None,
        "norm": None,
        "rope": None,
    }

    for i, ckpt in enumerate(checkpoints):
        print(f"Loading checkpoint {i}")
        # NOTE(security): depending on the installed torch version and its
        # ``weights_only`` default, torch.load may unpickle arbitrary
        # objects. Only load checkpoints from a trusted source.
        checkpoint = torch.load(ckpt, map_location="cpu")
        for parameter_name, parameter in model.named_parameters():
            short_name = parameter_name.split(".")[-2]
            dim = key_to_dim[short_name]
            if dim is None:
                # Unsplit parameter: only the first file's copy is used.
                if i == 0:
                    parameter.data = checkpoint[parameter_name]
            elif dim == 0:
                size = checkpoint[parameter_name].size(0)
                parameter.data[size * i: size * (i + 1), :] = checkpoint[
                    parameter_name
                ]
            elif dim == -1:
                size = checkpoint[parameter_name].size(-1)
                parameter.data[:, size * i: size * (i + 1)] = checkpoint[
                    parameter_name
                ]
            # Drop each tensor as soon as it has been copied to keep peak
            # memory down while the rest of the file is processed.
            del checkpoint[parameter_name]
        del checkpoint

    model.to("cpu")

    generator = LLaMA(model, tokenizer)
    print(f"Loaded model in {time.time() - start_time:.2f} seconds")
    return generator
def main(
    ckpt_dir: str,
    tokenizer_path: str,
    temperature: float = 0.8,
    top_p: float = 0.95,
    max_seq_len: int = 20,
    max_batch_size: int = 2,
):
    """Load the model and print completions for two built-in prompts.

    Args:
        ckpt_dir: Checkpoint directory, forwarded to ``load``.
        tokenizer_path: Tokenizer model path, forwarded to ``load``.
        temperature: Sampling temperature passed to ``generate``.
        top_p: ``top_p`` value passed to ``generate``.
        max_seq_len: Forwarded to ``load``.
        max_batch_size: Forwarded to ``load``.
    """
    separator = "\n==================================\n"

    llama = load(ckpt_dir, tokenizer_path, max_seq_len, max_batch_size)
    completions = llama.generate(
        ["The thing I like best is", "The key to the cabinets"],
        max_gen_len=5,
        temperature=temperature,
        top_p=top_p,
    )

    # Each completion is followed by its own separator line.
    for text in completions:
        print(text)
        print(separator)
# Script entry point: hand main() to python-fire, which builds the
# command-line interface from main's parameters.
if __name__ == "__main__":
    fire.Fire(main)