Skip to content

Commit a20e2f0

Browse files
author
csAugust
committed
Add speculative decoding support to Qwen 1.5 for CPU backend on Linux
1 parent b2a4a21 commit a20e2f0

18 files changed

+1946
-6
lines changed

examples/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ func_llm_add_executable(demo_bert)
7878
func_llm_add_executable(demo_phonelm)
7979
func_llm_add_executable(demo_llama3)
8080
func_llm_add_executable(demo_minicpm_moe_mbm)
81-
81+
func_llm_add_executable(demo_qwen_sd)
8282

8383

8484
func_vlm_add_executable(demo_llava)

examples/demo_qwen_sd.cpp

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
/**
 * @file demo_qwen_sd.cpp
 * @author Zhiyang Chen ([email protected])
 * @brief Demo of speculative decoding with Qwen 1.5 on the CPU backend.
 * @date 2025-3-11
 *
 *
 */
#include "cmdline.h"
#include "models/qwen/configuration_qwen.hpp"
#include "models/qwen/modeling_qwen_sd.hpp"
#include "models/qwen/tokenization_qwen.hpp"

using namespace mllm;

int main(int argc, char **argv) {
    // Detach C++ streams from C stdio: generation prints token-by-token,
    // so stream flushing is on the hot path.
    std::ios_base::sync_with_stdio(false);

    cmdline::parser cmdParser;
    cmdParser.add<string>("vocab", 'v', "specify mllm tokenizer model path", false, "../vocab/qwen_vocab.mllm");
    cmdParser.add<string>("merge", 'e', "specify mllm merge file path", false, "../vocab/qwen_merges.txt");
    cmdParser.add<string>("model", 'm', "specify mllm model path", false, "../models/qwen-1.5-1.8b-q8_0.mllm");
    cmdParser.add<string>("billion", 'b', "[0.5B | 1.8B | 1.5B |]", false, "1.8B");
    cmdParser.add<int>("limits", 'l', "max KV cache size", false, 400);
    cmdParser.add<int>("thread", 't', "num of threads", false, 4);
    cmdParser.parse_check(argc, argv);

    string vocab_path = cmdParser.get<string>("vocab");
    string merge_path = cmdParser.get<string>("merge");
    string model_path = cmdParser.get<string>("model");
    string model_billion = cmdParser.get<string>("billion");
    int tokens_limit = cmdParser.get<int>("limits");
    CPUBackend::cpu_threads = cmdParser.get<int>("thread");

    auto tokenizer = QWenTokenizer(vocab_path, merge_path);
    QWenConfig config(tokens_limit, model_billion, RoPEType::HFHUBROPE);
    auto model = QWenForCausalLM(config);
    model.load(model_path);

    vector<string> in_strs = {
        "Summarize: Hillary Clinton\u2019s security detail arrived at a suburban Des Moines, Iowa fruit processing company on Tuesday with an added vehicle \u2013 a second Scooby. After her signature oversize black Chevy conversion van dropped her off at Capitol Fruit Company in Norwalk, Iowa, a visually identical GMC van drove up to the building with a nearly identical Secret Service escort vehicle. Both armored vehicles have raised roofs, deep-tinted windows and New York license plates. But while the original van \u2013 the one nicknamed 'Scooby' after the Scooby-Doo cartoon show \u2013 sports a mustard-yellow New York tag, the second has blue and white plates of a different design. Scroll down for video. WHY BUY ONE WHEN YOU CAN HAVE TWO AT TWICE THE PRICE? The first picture of both of Hillary Clinton's Scooby mobiles. One is a GMC and the other is a Chevrolet, but they are mechanically identical. CONVOY: Scooby-one and Scooby-two took up positions in Hillary's motorcade on a freeway near Des Moines",
        "Hello, who are you?",
        "What can you do?",
        "Please introduce Beijing University of Posts and Telecommunications.",
    };
    // Range-for avoids the signed/unsigned mismatch of `int i < in_strs.size()`.
    for (const auto &prompt : in_strs) {
        auto input_str = tokenizer.apply_chat_template(prompt);
        auto input_tensor = tokenizer.tokenize(input_str);
        std::cout << "[Q] " << prompt << std::endl;
        std::cout << "[A] " << std::flush;

        LlmTextGeneratorOpts opt{
            .max_new_tokens = 50,
            .do_sample = false, // TODO: implement nucleus sampling for speculative decoding
        };
        // Streaming callback: print each accepted token until the tokenizer
        // reports end-of-sequence (not_end == false).
        model.generate(input_tensor, opt, [&](unsigned int out_token) -> bool {
            auto out_string = tokenizer.detokenize({out_token});
            auto [not_end, output_string] = tokenizer.postprocess(out_string);
            if (!not_end) { return false; }
            std::cout << output_string << std::flush;
            return true;
        });
        std::cout << "\n";
    }
}

include/OpDefined.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,10 @@ enum OpType {
7575
// new front-end
7676
SUPERSILU,
7777
HEADLINEAR,
78+
79+
// for speculative decoding
80+
ROPETREE,
81+
CAUSALTREEMASK,
7882
};
7983

8084
static const vector<string> OpNames = {

src/Draft.hpp

Lines changed: 284 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,284 @@
1+
/**
2+
* @file Draft.hpp
3+
* @author Zhiyang Chen ([email protected])
4+
* @brief
5+
* @date 2025-2-24
6+
*
7+
*
8+
*/
9+
#pragma once
10+
#ifndef MLLM_DRAFT_HPP
11+
#define MLLM_DRAFT_HPP
12+
#include <iostream>
13+
#include <vector>
14+
#include <unordered_map>
15+
#include <string>
16+
#include <deque>
17+
#include <algorithm>
18+
#include <cassert>
19+
20+
namespace mllm {
21+
22+
23+
/**
 * @brief Pool of candidate draft-token sequences ("traces") for speculative
 * decoding. Drafts are flattened into the model input as a token tree
 * (generate_draft), verified against the tokens the target model actually
 * sampled (evalPosterior), and the longest accepted prefix is recorded.
 */
class TracePool {
public:
    struct Trace {
        std::vector<unsigned int> trace_tokens; // drafted token ids of this candidate continuation
        Trace(const std::vector<unsigned int> &tokens) : trace_tokens(tokens) {}
    };

    /// Register one candidate continuation; empty drafts are silently ignored.
    void add_trace(const std::vector<unsigned int> &tokens) {
        if (tokens.empty()) {
            return;
        }
        traces.push_back(Trace(tokens));
    }

    /// Drop all candidate traces (verification bookkeeping is left untouched).
    void clear_trace() {
        traces.clear();
    }

    /// Restore the pool to its freshly-constructed state.
    void reset() {
        is_decoding = false;
        draft_length = 0;
        last_accept_cid = 0;
        last_accept_length = 0;
        last_draft_length = 0;
        traces.clear();
        last_accept_position_ids.clear();
        trace_position_ids.clear();
    }

    /// Trace selected by the most recent evalPosterior() call.
    inline const Trace &get_accepted_trace() const {
        return traces[last_accept_cid];
    }
    /// Accepted-prefix length from the most recent verification.
    inline unsigned int get_accepted_length() const {
        return last_accept_length;
    }
    /// Total number of draft tokens laid out by the last generate_draft().
    inline unsigned int get_draft_length() const {
        return draft_length;
    }

    /**
     * @brief Verify drafted tokens against the target model's samples and
     * pick the trace with the longest accepted prefix.
     *
     * Layout: sampled_token_ids[0] is the token sampled after the start
     * token; the samples at the draft positions follow in the same order
     * generate_draft() appended the traces.
     *
     * @param logit_scores      per-position scores (currently unused; kept in
     *                          the signature for future sampled verification)
     * @param sampled_token_ids tokens sampled by the target model; expected to
     *                          hold at least 1 + draft_length entries
     * @return the first token sampled after the accepted prefix, i.e. the next
     *         token to commit to the output
     */
    unsigned int evalPosterior(const std::vector<std::vector<float>> &logit_scores, const std::vector<unsigned int> &sampled_token_ids) {
        const size_t n_candidate = traces.size();
        unsigned int best_candidate_idx = 0;
        unsigned int max_accept_length = 0;
        unsigned int best_next_token_id = sampled_token_ids[0];

        size_t idx_offset = 0; // offset of the current trace's tokens inside the flattened draft
        for (size_t tid = 0; tid < n_candidate; tid++) {
            const std::vector<unsigned int> &trace_tokens = traces[tid].trace_tokens;
            const size_t trace_length = trace_tokens.size();
            unsigned int accept_length = 0;
            for (size_t i = 0; i < trace_length; i++) {
                // The prediction for a trace's first token comes from the
                // start token (slot 0); later ones come from the draft slots.
                const size_t tgt_idx = (i == 0) ? 0 : (idx_offset + i);
                if (trace_tokens[i] == sampled_token_ids[tgt_idx]) {
                    accept_length += 1;
                } else {
                    break;
                }
            }
            if (accept_length > max_accept_length) {
                max_accept_length = accept_length;
                best_candidate_idx = static_cast<unsigned int>(tid);
                best_next_token_id = sampled_token_ids[idx_offset + accept_length];
            }
            idx_offset += trace_length;
        }

        this->last_draft_length = this->draft_length;
        this->last_accept_cid = best_candidate_idx;
        this->last_accept_length = max_accept_length;
        this->last_accept_position_ids.clear();
        for (unsigned int i = 0; i < max_accept_length; i++) {
            this->last_accept_position_ids.push_back(this->trace_position_ids[best_candidate_idx][i]);
        }
        return best_next_token_id;
    }

    /**
     * @brief Flatten all traces into the model input as a token tree.
     *
     * Appends every trace's tokens to input_ids, assigns consecutive position
     * ids starting at cur_seq_length, and records each draft token's parent
     * slot in tree_ancestors (the first node of every trace hangs off the
     * start token, slot 0).
     *
     * @return total number of draft tokens appended
     */
    unsigned int generate_draft(std::vector<unsigned int> &input_ids, std::vector<unsigned int> &position_ids,
                                std::vector<int> &tree_ancestors, unsigned int cur_seq_length) {
        unsigned int draft_len = 0;
        this->trace_position_ids.clear();
        for (size_t i = 0; i < traces.size(); i++) {
            const unsigned int trace_len = static_cast<unsigned int>(traces[i].trace_tokens.size());
            input_ids.insert(input_ids.end(), traces[i].trace_tokens.begin(), traces[i].trace_tokens.end());
            tree_ancestors.push_back(0); // the first node of each trace points back to the start token
            std::vector<unsigned int> pos;
            for (unsigned int j = 0; j < trace_len; j++) {
                position_ids.push_back(draft_len + j + cur_seq_length);
                pos.push_back(draft_len + j + cur_seq_length);
                if (j > 0) {
                    // Parent of a non-first node is the previous node of the same trace.
                    tree_ancestors.push_back(static_cast<int>(draft_len + j));
                }
            }
            this->trace_position_ids.push_back(pos);
            draft_len += trace_len;
        }
        this->draft_length = draft_len;
        return draft_len;
    }

    std::vector<Trace> traces;
    bool is_decoding = false;
    unsigned int draft_length = 0; // total length of the draft segment
    // Results of the most recent verification:
    unsigned int last_accept_cid = 0;
    unsigned int last_accept_length = 0;
    unsigned int last_draft_length = 0;
    std::vector<unsigned int> last_accept_position_ids;
    std::vector<std::vector<unsigned int>> trace_position_ids;

private:
    // std::vector<std::vector<unsigned int>> candidate_token_ids;
    // std::vector<std::vector<unsigned int>> candidate_position_ids;
    // std::map<unsigned int, std::vector<unsigned int>> cid2pids;
    // std::vector<int> tree_ancestors;
};
147+
148+
149+
/**
 * @brief Suffix automaton over the generated token stream, used to draft
 * likely continuations for speculative decoding: it tracks the longest
 * suffix of the stream that occurred before, and proposes the tokens that
 * followed that earlier occurrence.
 */
class SuffixAutomaton {
public:
    struct State {
        std::unordered_map<int, int> next; // outgoing transitions keyed by token id
        int link = -1;                     // suffix link
        int length = 0;                    // length of the longest string in this state
        int min_endpos = 0;                // smallest end position of this state in the stream
        State() = default;
        State(int link, int length, int min_endpos) : link(link), length(length), min_endpos(min_endpos) {}
    };

    SuffixAutomaton() {
        states.push_back(State(-1, 0, 0)); // initial (root) state
        input_ids.push_back(-1);           // sentinel so stream positions start at 1
        last = 0;
        max_length = 0;
        cur_index = 0;
        cur_length = 0;
    }

    /// Rebuild from scratch, discarding the whole token history.
    void reset() {
        states.clear();
        states.push_back(State(-1, 0, 0));
        input_ids.clear();
        input_ids.push_back(-1);
        last = 0;
        max_length = 0;
        cur_index = 0;
        cur_length = 0;
    }

    /// Append tokens to the stream: each token first advances the running
    /// match state, then extends the automaton.
    void add_tokens(const std::vector<unsigned int> &tokens) {
        for (unsigned int token : tokens) {
            transfer_cur_state(token);
            add_state(token);
        }
        input_ids.insert(input_ids.end(), tokens.begin(), tokens.end());
    }

    /// From the current match state, follow start_token and return the
    /// resulting {state index, matched length}; the automaton is not mutated.
    std::pair<int, int> lookup(int start_token) const {
        int index = cur_index;
        int length = cur_length;
        transfer_state(index, length, start_token);
        return {index, length};
    }

    /**
     * @brief Emit a draft sequence copied from just after the earliest
     * previous occurrence of the matched context.
     *
     * Draft size grows with the match length (1 + alpha * match_length),
     * capped at max_predicts and bounded by the end of the stream.
     *
     * @param seq            output draft tokens (cleared first)
     * @param index          automaton state returned by lookup()
     * @param match_length   matched-context length returned by lookup()
     * @param start_token    unused here; kept for interface compatibility
     * @param minimum_length lower bound on the draft size when a match exists
     * @return number of tokens written to seq
     */
    int gen_draft(std::vector<unsigned int> &seq, int index, int match_length, unsigned int start_token, int minimum_length = 0) {
        int n = std::min(max_predicts, 1 + static_cast<int>(match_length * alpha));
        if (minimum_length > 0 && match_length > 0) {
            n = std::max(minimum_length, n);
        }
        const int endpos = states[index].min_endpos;
        seq.clear();
        for (int i = endpos + 1; i < endpos + n; ++i) {
            // Explicit cast: i >= 1 here, so the signed->unsigned compare is safe.
            if (static_cast<size_t>(i) >= input_ids.size()) break;
            seq.push_back(static_cast<unsigned int>(input_ids[i]));
        }
        return static_cast<int>(seq.size());
    }

    /// Debug dump of all states. NOTE(review): transitions are printed as
    /// char('a' + id), which only reads sensibly for tiny token ids — confirm
    /// before relying on this output.
    void print() const {
        for (size_t i = 1; i < states.size(); ++i) {
            std::cout << "State " << i << ": length = " << states[i].length << ", link = " << states[i].link << ", min_endpos = " << states[i].min_endpos << std::endl;
            for (const auto &[ch, next_state] : states[i].next) {
                std::cout << "  " << char('a' + ch) << " -> " << next_state << std::endl;
            }
        }
    }

private:
    std::vector<State> states;
    int last;                   // state representing the entire stream inserted so far
    int max_length;             // number of tokens inserted so far
    int cur_index = 0;          // running match state (follows the generated stream)
    int cur_length = 0;         // running match length
    int max_predicts = 40;      // hard cap on draft length
    float alpha = 4.0f;         // draft size scale factor w.r.t. match length
    std::vector<int> input_ids; // token stream, with a -1 sentinel at index 0

    /// Append a copy of `state` and return its index.
    unsigned int expand_state(const State &state) {
        unsigned int new_index = states.size();
        states.push_back(state);
        return new_index;
    }

    /// Standard online suffix-automaton extension by one token `c`,
    /// including the clone step that splits a state on a length mismatch.
    void add_state(int c) {
        max_length += 1;
        int cur = expand_state(State(-1, max_length, max_length));
        int p = last;
        // Give every suffix-link ancestor lacking a c-transition one to cur.
        while (p != -1 && states[p].next.count(c) == 0) {
            states[p].next[c] = cur;
            p = states[p].link;
        }

        if (p == -1) {
            states[cur].link = 0;
        } else {
            int q = states[p].next[c];
            if (states[p].length + 1 == states[q].length) {
                states[cur].link = q;
            } else {
                // Split: clone q with the shorter length, redirect transitions.
                int clone = states.size();
                states.push_back(states[q]);
                states[clone].length = states[p].length + 1;
                while (p != -1 && states[p].next[c] == q) {
                    states[p].next[c] = clone;
                    p = states[p].link;
                }
                states[q].link = states[cur].link = clone;
            }
        }
        last = cur;
    }

    /// Walk suffix links until `token` can be consumed; on total failure fall
    /// back to the root with a zero-length match.
    void transfer_state(int &index, int &length, int token) const {
        while (index != 0 && states[index].next.count(token) == 0) {
            index = states[index].link;
            length = states[index].length;
        }
        if (states[index].next.count(token)) {
            index = states[index].next.at(token);
            length++;
        } else {
            index = length = 0;
        }
    }

    /// Advance the running match state by one token.
    void transfer_cur_state(int token) {
        transfer_state(cur_index, cur_length, token);
    }
};
281+
282+
} // namespace mllm
283+
284+
#endif //! MLLM_DRAFT_HPP

0 commit comments

Comments
 (0)