1
+ /* *
2
+ * @file Draft.hpp
3
+ * @author Zhiyang Chen ([email protected] )
4
+ * @brief
5
+ * @date 2025-2-24
6
+ *
7
+ *
8
+ */
9
+ #pragma once
10
+ #ifndef MLLM_DRAFT_HPP
11
+ #define MLLM_DRAFT_HPP
12
+ #include < iostream>
13
+ #include < vector>
14
+ #include < unordered_map>
15
+ #include < string>
16
+ #include < deque>
17
+ #include < algorithm>
18
+ #include < cassert>
19
+
20
+ namespace mllm {
21
+
22
+
23
+ class TracePool {
24
+ public:
25
+ struct Trace {
26
+ std::vector<unsigned int > trace_tokens;
27
+ Trace (const std::vector<unsigned int > &tokens) : trace_tokens(tokens) {}
28
+ };
29
+
30
+ void add_trace (const std::vector<unsigned int > &tokens) {
31
+ if (tokens.empty ()) {
32
+ return ;
33
+ }
34
+ traces.push_back (Trace (tokens));
35
+ }
36
+
37
+ void clear_trace () {
38
+ traces.clear ();
39
+ }
40
+
41
+ void reset () {
42
+ is_decoding = false ;
43
+ draft_length = 0 ;
44
+ last_accept_cid = 0 ;
45
+ last_accept_length = 0 ;
46
+ last_draft_length = 0 ;
47
+ traces.clear ();
48
+ last_accept_position_ids.clear ();
49
+ trace_position_ids.clear ();
50
+ }
51
+
52
+ inline const Trace& get_accepted_trace () {
53
+ return traces[last_accept_cid];
54
+ }
55
+ inline unsigned int get_accepted_length () {
56
+ return last_accept_length;
57
+ }
58
+ inline unsigned int get_draft_length () {
59
+ return draft_length;
60
+ }
61
+ // inline unsigned int get_n_trace() {
62
+ // return traces.size();
63
+ // }
64
+
65
+ unsigned int evalPosterior (const std::vector<std::vector<float >> &logit_scores, const std::vector<unsigned int > &sampled_token_ids) {
66
+ std::vector<unsigned int > accept_lengths;
67
+ int n_candidate = traces.size ();
68
+ unsigned int best_candidate_idx = 0 ;
69
+ unsigned int max_accept_length = 0 ;
70
+ unsigned int best_next_token_id = sampled_token_ids[0 ];
71
+
72
+ int idx_offset = 0 ; // draft token被放到input_ids后的偏移量
73
+ for (int tid = 0 ; tid < n_candidate; tid++) {
74
+ const std::vector<unsigned int > &trace_tokens = traces[tid].trace_tokens ;
75
+ unsigned int trace_length = trace_tokens.size ();
76
+ unsigned int accept_length = 0 ;
77
+ for (int i = 0 ; i < trace_length; i++) {
78
+ int src_idx = i;
79
+ int tgt_idx = (i == 0 )? (0 ) : (idx_offset + i);
80
+ if (trace_tokens[src_idx] == sampled_token_ids[tgt_idx]) {
81
+ accept_length += 1 ;
82
+ } else {
83
+ break ;
84
+ }
85
+ }
86
+ if (accept_length > max_accept_length) {
87
+ max_accept_length = accept_length;
88
+ best_candidate_idx = tid;
89
+ best_next_token_id = sampled_token_ids[idx_offset + accept_length];
90
+ }
91
+ idx_offset += trace_length;
92
+ accept_lengths.push_back (accept_length);
93
+ }
94
+
95
+ this ->last_draft_length = this ->draft_length ;
96
+ this ->last_accept_cid = best_candidate_idx;
97
+ this ->last_accept_length = max_accept_length;
98
+ this ->last_accept_position_ids .clear ();
99
+ for (int i = 0 ; i < max_accept_length; i++) {
100
+ this ->last_accept_position_ids .push_back (this ->trace_position_ids [best_candidate_idx][i]);
101
+ }
102
+ // std::cout << "Accept length: " << max_accept_length << std::endl;
103
+ return best_next_token_id;
104
+ }
105
+
106
+
107
+ unsigned int generate_draft (std::vector<unsigned int > &input_ids, std::vector<unsigned int > &position_ids,
108
+ std::vector<int > &tree_ancestors, unsigned int cur_seq_length) {
109
+ unsigned int draft_len = 0 ;
110
+ this ->trace_position_ids .clear ();
111
+ for (int i = 0 ; i < traces.size (); i++) {
112
+ unsigned int trace_len = traces[i].trace_tokens .size ();
113
+ input_ids.insert (input_ids.end (), traces[i].trace_tokens .begin (), traces[i].trace_tokens .end ());
114
+ tree_ancestors.push_back (0 ); // 每个trace的首节点总是指向start token
115
+ std::vector<unsigned int > pos;
116
+ for (int j = 0 ; j < trace_len; j++) {
117
+ position_ids.push_back (draft_len + j + cur_seq_length);
118
+ pos.push_back (draft_len + j + cur_seq_length);
119
+ if (j > 0 ) {
120
+ tree_ancestors.push_back (draft_len + j);
121
+ }
122
+ }
123
+ this ->trace_position_ids .push_back (pos);
124
+ draft_len += trace_len;
125
+ }
126
+ this ->draft_length = draft_len;
127
+ return draft_len;
128
+ }
129
+
130
+ std::vector<Trace> traces;
131
+ bool is_decoding = false ;
132
+ unsigned int draft_length = 0 ; // draft部分的总长度
133
+ // 记录上一次verify的结果
134
+ unsigned int last_accept_cid = 0 ;
135
+ unsigned int last_accept_length = 0 ;
136
+ unsigned int last_draft_length = 0 ;
137
+ std::vector<unsigned int > last_accept_position_ids;
138
+ std::vector<std::vector<unsigned int >> trace_position_ids;
139
+
140
+ private:
141
+ // std::vector<std::vector<unsigned int>> candidate_token_ids;
142
+ // std::vector<std::vector<unsigned int>> candidate_position_ids;
143
+ // std::map<unsigned int, std::vector<unsigned int>> cid2pids;
144
+ // std::vector<int> tree_ancestors;
145
+
146
+ };
147
+
148
+
149
+ class SuffixAutomaton {
150
+ public:
151
+ struct State {
152
+ std::unordered_map<int , int > next; // 存储字符ID对应的转移状态
153
+ int link = -1 ; // 后缀链接
154
+ int length = 0 ; // 当前状态的长度
155
+ int min_endpos = 0 ; // 当前状态的最小结束位置
156
+ State () = default ;
157
+ State (int link, int length, int min_endpos) : link(link), length(length), min_endpos(min_endpos) {}
158
+ };
159
+
160
+ SuffixAutomaton () {
161
+ states.push_back (State (-1 , 0 , 0 )); // 重新初始化状态
162
+ input_ids.push_back (-1 );
163
+ last = 0 ;
164
+ max_length = 0 ;
165
+ cur_index = 0 ;
166
+ cur_length = 0 ;
167
+ }
168
+
169
+ void reset () {
170
+ states.clear ();
171
+ states.push_back (State (-1 , 0 , 0 ));
172
+ input_ids.clear ();
173
+ input_ids.push_back (-1 );
174
+ last = 0 ;
175
+ max_length = 0 ;
176
+ cur_index = 0 ;
177
+ cur_length = 0 ;
178
+ }
179
+
180
+ void add_tokens (const std::vector<unsigned int >& tokens) {
181
+ for (unsigned int token : tokens) {
182
+ transfer_cur_state (token);
183
+ add_state (token);
184
+ }
185
+ input_ids.insert (input_ids.end (), tokens.begin (), tokens.end ());
186
+ }
187
+
188
+ std::pair<int , int > lookup (int start_token) const {
189
+ int index = cur_index;
190
+ int length = cur_length;
191
+ transfer_state (index , length, start_token);
192
+ return {index , length};
193
+ }
194
+
195
+ int gen_draft (std::vector<unsigned int > &seq, int index, int match_length, unsigned int start_token, int minimum_length = 0 ) {
196
+ int n = std::min (max_predicts, 1 + static_cast <int >(match_length * alpha));
197
+ if (minimum_length > 0 && match_length > 0 ) {
198
+ n = std::max (minimum_length, n);
199
+ }
200
+ int endpos = states[index ].min_endpos ;
201
+ seq.clear ();
202
+ for (int i = endpos + 1 ; i < endpos + n; ++i) {
203
+ if (i >= input_ids.size ()) break ;
204
+ seq.push_back (input_ids[i]);
205
+ }
206
+ return seq.size ();
207
+ }
208
+
209
+ void print () const {
210
+ for (size_t i = 1 ; i < states.size (); ++i) {
211
+ std::cout << " State " << i << " : length = " << states[i].length << " , link = " << states[i].link << " , min_endpos = " << states[i].min_endpos << std::endl;
212
+ for (const auto & [ch, next_state] : states[i].next ) {
213
+ std::cout << " " << char (' a' + ch) << " -> " << next_state << std::endl;
214
+ }
215
+ }
216
+ }
217
+
218
+ private:
219
+ std::vector<State> states;
220
+ int last;
221
+ int max_length;
222
+ int cur_index = 0 ;
223
+ int cur_length = 0 ;
224
+ int max_predicts = 40 ;
225
+ float alpha = 4 .0f ;
226
+ std::vector<int > input_ids;
227
+
228
+ unsigned int expand_state (const State &state) {
229
+ unsigned int new_index = states.size ();
230
+ states.push_back (state);
231
+ return new_index;
232
+ }
233
+
234
+ void add_state (int c) {
235
+ max_length += 1 ;
236
+ int cur = expand_state (State (-1 , max_length, max_length));
237
+ int p = last;
238
+ while (p != -1 && states[p].next .count (c) == 0 ) {
239
+ states[p].next [c] = cur;
240
+ p = states[p].link ;
241
+ }
242
+
243
+ if (p == -1 ) {
244
+ states[cur].link = 0 ;
245
+ } else {
246
+ int q = states[p].next [c];
247
+ if (states[p].length + 1 == states[q].length ) {
248
+ states[cur].link = q;
249
+ } else {
250
+ int clone = states.size ();
251
+ states.push_back (states[q]);
252
+ states[clone].length = states[p].length + 1 ;
253
+ while (p != -1 && states[p].next [c] == q) {
254
+ states[p].next [c] = clone;
255
+ p = states[p].link ;
256
+ }
257
+ states[q].link = states[cur].link = clone;
258
+ }
259
+ }
260
+ last = cur;
261
+ }
262
+
263
+ void transfer_state (int & index, int & length, int token) const {
264
+ while (index != 0 && states[index ].next .count (token) == 0 ) {
265
+ index = states[index ].link ;
266
+ length = states[index ].length ;
267
+ }
268
+ if (states[index ].next .count (token)) {
269
+ index = states[index ].next .at (token);
270
+ length++;
271
+ } else {
272
+ index = length = 0 ;
273
+ }
274
+ }
275
+
276
+ void transfer_cur_state (int token) {
277
+ transfer_state (cur_index, cur_length, token);
278
+ }
279
+
280
+ };
281
+
282
+ } // namespace mllm
283
+
284
+ #endif // ! MLLM_DRAFT_HPP
0 commit comments