diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index e9659c33..86e8663f 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -79,7 +79,7 @@ func_llm_add_executable(demo_bert)
 func_llm_add_executable(demo_phonelm)
 func_llm_add_executable(demo_llama3)
 func_llm_add_executable(demo_minicpm_moe_mbm)
-
+func_llm_add_executable(demo_qwen_sd)
 
 func_vlm_add_executable(demo_llava)

diff --git a/examples/demo_qwen_sd.cpp b/examples/demo_qwen_sd.cpp
new file mode 100644
index 00000000..16834be4
--- /dev/null
+++ b/examples/demo_qwen_sd.cpp
@@ -0,0 +1,65 @@
+/**
+ * @file demo_qwen_sd.cpp
+ * @author Zhiyang Chen (zhiyangchen@stu.pku.edu.cn)
+ * @brief Demo of Qwen with speculative decoding
+ * @date 2025-3-11
+ */
+#include "cmdline.h"
+#include "models/qwen/configuration_qwen.hpp"
+#include "models/qwen/modeling_qwen_sd.hpp"
+#include "models/qwen/tokenization_qwen.hpp"
+
+using namespace mllm;
+
+int main(int argc, char **argv) {
+    std::iostream::sync_with_stdio(false);
+
+    cmdline::parser cmdParser;
+    cmdParser.add<string>("vocab", 'v', "specify mllm tokenizer model path", false, "../vocab/qwen_vocab.mllm");
+    cmdParser.add<string>("merge", 'e', "specify mllm merge file path", false, "../vocab/qwen_merges.txt");
+    cmdParser.add<string>("model", 'm', "specify mllm model path", false, "../models/qwen-1.5-1.8b-q8_0.mllm");
+    cmdParser.add<string>("billion", 'b', "[0.5B | 1.8B | 1.5B |]", false, "1.8B");
+    cmdParser.add<int>("limits", 'l', "max KV cache size", false, 400);
+    cmdParser.add<int>("thread", 't', "num of threads", false, 4);
+    cmdParser.parse_check(argc, argv);
+
+    string vocab_path = cmdParser.get<string>("vocab");
+    string merge_path = cmdParser.get<string>("merge");
+    string model_path = cmdParser.get<string>("model");
+    string model_billion = cmdParser.get<string>("billion");
+    int tokens_limit = cmdParser.get<int>("limits");
+    CPUBackend::cpu_threads = cmdParser.get<int>("thread");
+
+    auto tokenizer = QWenTokenizer(vocab_path, merge_path);
+    QWenConfig config(tokens_limit, model_billion, RoPEType::HFHUBROPE);
+    auto model = QWenForCausalLM(config);
+    model.load(model_path);
+
+    vector<string> in_strs = {
+        "Summarize: Hillary Clinton\u2019s security detail arrived at a suburban Des Moines, Iowa fruit processing company on Tuesday with an added vehicle \u2013 a second Scooby. After her signature oversize black Chevy conversion van dropped her off at Capitol Fruit Company in Norwalk, Iowa, a visually identical GMC van drove up to the building with a nearly identical Secret Service escort vehicle. Both armored vehicles have raised roofs, deep-tinted windows and New York license plates. But while the original van \u2013 the one nicknamed 'Scooby' after the Scooby-Doo cartoon show \u2013 sports a mustard-yellow New York tag, the second has blue and white plates of a different design. Scroll down for video. WHY BUY ONE WHEN YOU CAN HAVE TWO AT TWICE THE PRICE? The first picture of both of Hillary Clinton's Scooby mobiles. One is a GMC and the other is a Chevrolet, but they are mechanically identical. CONVOY: Scooby-one and Scooby-two took up positions in Hillary's motorcade on a freeway near Des Moines",
+        "Hello, who are you?",
+        "What can you do?",
+        "Please introduce Beijing University of Posts and Telecommunications.",
+    };
+    for (int i = 0; i < in_strs.size(); ++i) {
+        auto input_str = tokenizer.apply_chat_template(in_strs[i]);
+        auto input_tensor = tokenizer.tokenize(input_str);
+        std::cout << "[Q] " << in_strs[i] << std::endl;
+        std::cout << "[A] " << std::flush;
+
+        LlmTextGeneratorOpts opt{
+            .max_new_tokens = 50,
+            .do_sample = false, // TODO: implement nucleus sampling for speculative decoding
+        };
+        model.generate(input_tensor, opt, [&](unsigned int out_token) -> bool {
+            auto out_string = tokenizer.detokenize({out_token});
+            auto [not_end, output_string] = tokenizer.postprocess(out_string);
+            if (!not_end) { return false; }
+            std::cout << output_string << std::flush;
+            return true;
+        });
+        std::cout << "\n";
+    }
+}

diff --git a/include/OpDefined.hpp b/include/OpDefined.hpp
index ed880635..595c7086 100644
--- a/include/OpDefined.hpp
+++ b/include/OpDefined.hpp
@@ -77,6 +77,10 @@ enum OpType {
     // new front-end
     SUPERSILU,
     HEADLINEAR,
+
+    // for speculative decoding
+    ROPETREE,
+    CAUSALTREEMASK,
 };
 
 static const vector<string> OpNames = {

diff --git a/src/Draft.hpp b/src/Draft.hpp
new file mode 100644
index 00000000..b1d2a847
--- /dev/null
+++ b/src/Draft.hpp
@@ -0,0 +1,284 @@
+/**
+ * @file Draft.hpp
+ * @author Zhiyang Chen (zhiyangchen@stu.pku.edu.cn)
+ * @brief Draft management for speculative decoding
+ * @date 2025-2-24
+ */
+#pragma once
+#ifndef MLLM_DRAFT_HPP
+#define MLLM_DRAFT_HPP
+#include <algorithm>
+#include <cassert>
+#include <iostream>
+#include <map>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+namespace mllm {
+
+class TracePool {
+public:
+    struct Trace {
+        std::vector<unsigned int> trace_tokens;
+        Trace(const std::vector<unsigned int> &tokens) :
+            trace_tokens(tokens) {}
+    };
+
+    void add_trace(const std::vector<unsigned int> &tokens) {
+        if (tokens.empty()) {
+            return;
+        }
+        traces.push_back(Trace(tokens));
+    }
+
+    void clear_trace() {
+        traces.clear();
+    }
+
+    void reset() {
+        is_decoding = false;
+        draft_length = 0;
+        last_accept_cid = 0;
+        last_accept_length = 0;
+        last_draft_length = 0;
+        traces.clear();
+        last_accept_position_ids.clear();
+        trace_position_ids.clear();
+    }
+
+    inline const Trace &get_accepted_trace() {
+        return traces[last_accept_cid];
+    }
+    inline unsigned int get_accepted_length() {
+        return last_accept_length;
+    }
+    inline unsigned int get_draft_length() {
+        return draft_length;
+    }
+    // inline unsigned int get_n_trace() {
+    //     return traces.size();
+    // }
+
+    unsigned int evalPosterior(const std::vector<std::vector<float>> &logit_scores, const std::vector<unsigned int> &sampled_token_ids) {
+        std::vector<unsigned int> accept_lengths;
+        int n_candidate = traces.size();
+        unsigned int best_candidate_idx = 0;
+        unsigned int max_accept_length = 0;
+        unsigned int best_next_token_id = sampled_token_ids[0];
+
+        int idx_offset = 0; // offset of this trace's draft tokens after they are appended to input_ids
+        for (int tid = 0; tid < n_candidate; tid++) {
+            const std::vector<unsigned int> &trace_tokens = traces[tid].trace_tokens;
+            unsigned int trace_length = trace_tokens.size();
+            unsigned int accept_length = 0;
+            for (int i = 0; i < trace_length; i++) {
+                int src_idx = i;
+                int tgt_idx = (i == 0) ? (0) : (idx_offset + i);
+                if (trace_tokens[src_idx] == sampled_token_ids[tgt_idx]) {
+                    accept_length += 1;
+                } else {
+                    break;
+                }
+            }
+            if (accept_length > max_accept_length) {
+                max_accept_length = accept_length;
+                best_candidate_idx = tid;
+                best_next_token_id = sampled_token_ids[idx_offset + accept_length];
+            }
+            idx_offset += trace_length;
+            accept_lengths.push_back(accept_length);
+        }
+
+        this->last_draft_length = this->draft_length;
+        this->last_accept_cid = best_candidate_idx;
+        this->last_accept_length = max_accept_length;
+        this->last_accept_position_ids.clear();
+        for (int i = 0; i < max_accept_length; i++) {
+            this->last_accept_position_ids.push_back(this->trace_position_ids[best_candidate_idx][i]);
+        }
+        // std::cout << "Accept length: " << max_accept_length << std::endl;
+        return best_next_token_id;
+    }
+
+    unsigned int generate_draft(std::vector<unsigned int> &input_ids, std::vector<unsigned int> &position_ids,
+                                std::vector<int> &tree_ancestors, unsigned int cur_seq_length) {
+        unsigned int draft_len = 0;
+        this->trace_position_ids.clear();
+        for (int i = 0; i < traces.size(); i++) {
+            unsigned int trace_len = traces[i].trace_tokens.size();
+            input_ids.insert(input_ids.end(), traces[i].trace_tokens.begin(), traces[i].trace_tokens.end());
+            tree_ancestors.push_back(0); // the head of every trace points back to the start token
+            std::vector<unsigned int> pos;
+            for (int j = 0; j < trace_len; j++) {
+                position_ids.push_back(draft_len + j + cur_seq_length);
+                pos.push_back(draft_len + j + cur_seq_length);
+                if (j > 0) {
+                    tree_ancestors.push_back(draft_len + j);
+                }
+            }
+            this->trace_position_ids.push_back(pos);
+            draft_len += trace_len;
+        }
+        this->draft_length = draft_len;
+        return draft_len;
+    }
+
+    std::vector<Trace> traces;
+    bool is_decoding = false;
+    unsigned int draft_length = 0; // total length of the draft part
+    // results of the last verification round
+    unsigned int last_accept_cid = 0;
+    unsigned int last_accept_length = 0;
+    unsigned int last_draft_length = 0;
+    std::vector<unsigned int> last_accept_position_ids;
+    std::vector<std::vector<unsigned int>> trace_position_ids;
+
+private:
+    // std::vector<std::vector<unsigned int>> candidate_token_ids;
+    // std::vector<std::vector<unsigned int>> candidate_position_ids;
+    // std::map<int, std::vector<unsigned int>> cid2pids;
+    // std::vector<int> tree_ancestors;
+};
+
+class SuffixAutomaton {
+public:
+    struct State {
+        std::unordered_map<int, int> next; // transitions keyed by token id
+        int link = -1;                     // suffix link
+        int length = 0;                    // length of the longest string in this state
+        int min_endpos = 0;                // minimal end position of this state
+        State() = default;
+        State(int link, int length, int min_endpos) :
+            link(link), length(length), min_endpos(min_endpos) {}
+    };
+
+    SuffixAutomaton() {
+        states.push_back(State(-1, 0, 0)); // initial state
+        input_ids.push_back(-1);
+        last = 0;
+        max_length = 0;
+        cur_index = 0;
+        cur_length = 0;
+    }
+
+    void reset() {
+        states.clear();
+        states.push_back(State(-1, 0, 0));
+        input_ids.clear();
+        input_ids.push_back(-1);
+        last = 0;
+        max_length = 0;
+        cur_index = 0;
+        cur_length = 0;
+    }
+
+    void add_tokens(const std::vector<unsigned int> &tokens) {
+        for (unsigned int token : tokens) {
+            transfer_cur_state(token);
+            add_state(token);
+        }
+        input_ids.insert(input_ids.end(), tokens.begin(), tokens.end());
+    }
+
+    std::pair<int, int> lookup(int start_token) const {
+        int index = cur_index;
+        int length = cur_length;
+        transfer_state(index, length, start_token);
+        return {index, length};
+    }
+
+    int gen_draft(std::vector<unsigned int> &seq, int index, int match_length, unsigned int start_token, int minimum_length = 0) {
+        int n = std::min(max_predicts, 1 + static_cast<int>(match_length * alpha));
+        if (minimum_length > 0 && match_length > 0) {
+            n = std::max(minimum_length, n);
+        }
+        int endpos = states[index].min_endpos;
+        seq.clear();
+        for (int i = endpos + 1; i < endpos + n; ++i) {
+            if (i >= input_ids.size()) break;
+            seq.push_back(input_ids[i]);
+        }
+        return seq.size();
+    }
+
+    void print() const {
+        for (size_t i = 1; i < states.size(); ++i) {
+            std::cout << "State " << i << ": length = " << states[i].length << ", link = " << states[i].link << ", min_endpos = " << states[i].min_endpos << std::endl;
+            for (const auto &[ch, next_state] : states[i].next) {
+                std::cout << "  " << char('a' + ch) << " -> " << next_state << std::endl;
+            }
+        }
+    }
+
+private:
+    std::vector<State> states;
+    int last;
+    int max_length;
+    int cur_index = 0;
+    int cur_length = 0;
+    int max_predicts = 40;
+    float alpha = 4.0f;
+    std::vector<int> input_ids;
+
+    unsigned int expand_state(const State &state) {
+        unsigned int new_index = states.size();
+        states.push_back(state);
+        return new_index;
+    }
+
+    void add_state(int c) {
+        max_length += 1;
+        int cur = expand_state(State(-1, max_length, max_length));
+        int p = last;
+        while (p != -1 && states[p].next.count(c) == 0) {
+            states[p].next[c] = cur;
+            p = states[p].link;
+        }
+
+        if (p == -1) {
+            states[cur].link = 0;
+        } else {
+            int q = states[p].next[c];
+            if (states[p].length + 1 == states[q].length) {
+                states[cur].link = q;
+            } else {
+                int clone = states.size();
+                states.push_back(states[q]);
+                states[clone].length = states[p].length + 1;
+                while (p != -1 && states[p].next[c] == q) {
+                    states[p].next[c] = clone;
+                    p = states[p].link;
+                }
+                states[q].link = states[cur].link = clone;
+            }
+        }
+        last = cur;
+    }
+
+    void transfer_state(int &index, int &length, int token) const {
+        while (index != 0 && states[index].next.count(token) == 0) {
+            index = states[index].link;
+            length = states[index].length;
+        }
+        if (states[index].next.count(token)) {
+            index = states[index].next.at(token);
+            length++;
+        } else {
+            index = length = 0;
+        }
+    }
+
+    void transfer_cur_state(int token) {
+        transfer_state(cur_index, cur_length, token);
+    }
+};
+
+} // namespace mllm
+
+#endif //! MLLM_DRAFT_HPP
\ No newline at end of file
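For orientation, here is a hypothetical round trip through the two helpers above (the token ids, sequence lengths, and the `{-1}` seeding of `tree_ancestors` are made up for illustration; in the real flow the model's generate loop drives this):

#include "Draft.hpp"
using namespace mllm;

int main() {
    // 1) Propose a draft from repeated history via the suffix automaton.
    SuffixAutomaton sam;
    sam.add_tokens({5, 6, 7, 5, 6});           // tokens generated so far
    auto [state, match_len] = sam.lookup(7);   // longest earlier occurrence of a suffix ending in ...7
    std::vector<unsigned int> draft;
    sam.gen_draft(draft, state, match_len, 7); // proposes {5, 6}: what followed the first "5 6 7"

    // 2) Stage the draft as one trace and verify it after a forward pass.
    TracePool pool;
    pool.add_trace(draft);
    std::vector<unsigned int> input_ids = {7}; // start token = last verified token
    std::vector<unsigned int> position_ids;
    std::vector<int> tree_ancestors = {-1};    // sentinel for the start token (assumption)
    pool.generate_draft(input_ids, position_ids, tree_ancestors, /*cur_seq_length=*/5);
    pool.is_decoding = true;
    // Suppose the per-position argmax of the model is {5, 6, 9}: both draft
    // tokens are accepted and 9 is the bonus token sampled after them.
    unsigned int next = pool.evalPosterior({}, {5, 6, 9});
    return next == 9 ? 0 : 1;
}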
diff --git a/src/Generate.cpp b/src/Generate.cpp
index 6f6f74b1..1ef46a7c 100644
--- a/src/Generate.cpp
+++ b/src/Generate.cpp
@@ -20,6 +20,28 @@ unsigned int _LlmTextGenerateGreedySearchMethod::generate(Tensor &t) {
     return std::max_element(scores.begin(), scores.end()) - scores.begin();
 }
 
+unsigned int _LlmTextGenerateGreedySearchMethodForSD::generate_SD(Tensor &t, TracePool &tp) {
+    if (!tp.is_decoding) {
+        std::vector<float> scores;
+        this->_tensor_to_vec(t, scores); // _tensor_to_vec only extracts the values at the last position of seq_len
+        return std::max_element(scores.begin(), scores.end()) - scores.begin();
+    }
+
+    // Convert the logits tensor into vector<vector<float>>, where each inner vector holds all values at one position.
+    std::vector<std::vector<float>> scores;
+    this->_tensor_to_multivec(t, scores);
+
+    // Take the argmax at every position (the verified token id).
+    std::vector<unsigned int> sampled_token_ids;
+    for (int i = 0; i < scores.size(); ++i) {
+        unsigned int sampled_token_id = std::max_element(scores[i].begin(), scores[i].end()) - scores[i].begin();
+        sampled_token_ids.push_back(sampled_token_id);
+    }
+
+    const auto best_next_token_id = tp.evalPosterior(scores, sampled_token_ids);
+    return best_next_token_id;
+}
+
 unsigned int _LlmTextGenerateTopkSamplingMethod::generate(Tensor &t) {
     auto argmax = [](const std::vector<float> &vec) -> unsigned int {
         return std::distance(vec.begin(), std::max_element(vec.begin(), vec.end()));
@@ -120,4 +142,17 @@ unsigned int _LlmTextGenerateToppSamplingMethod::generate(Tensor &t) {
     return ret;
 }
 
+unsigned int _LlmTextGenerateNucleusSamplingMethodForSD::generate_SD(Tensor &t, TracePool &tp) {
+    // TODO
+    // Convert the logits tensor into vector<vector<pair<float, index>>>, one inner vector per position.
+    std::vector<std::vector<std::pair<float, int>>> scores;
+    // int seq_length = t.sequence();
+    // std::vector<int> pos;
+    // for (int i = 0; i < tp.get_draft_length() + 1; i++) { // the last verified (non-draft) token of the previous round must also be included when computing logits
+    //     pos.push_back(i);
+    // }
+    this->_tensor_to_multivec_with_idx(t, scores); // _tensor_to_vec only extracts the last position of seq_len
+    return 0;
+}
+
 } // namespace mllm
\ No newline at end of file
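The greedy verify path above boils down to one argmax per sequence position. As a standalone sketch (assuming the logits tensor has already been flattened to one row of vocabulary scores per position, as `_tensor_to_multivec` does):

#include <algorithm>
#include <vector>

// One verified token id per sequence position, mirroring the loop in generate_SD.
std::vector<unsigned int> argmax_per_position(const std::vector<std::vector<float>> &scores) {
    std::vector<unsigned int> ids;
    ids.reserve(scores.size());
    for (const auto &row : scores) {
        ids.push_back(std::max_element(row.begin(), row.end()) - row.begin());
    }
    return ids;
}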
e.g.:[1, 1, seq, hidden]"); + int _dims = t.dimension(); + int n_seq = t.sequence(); + // TODO: 考虑QNN进行padding + // padding prefill for QNN + // if (is_padding) { + // if (chunk_size > 0) { + // _seq = (seq_before_padding - 1) % chunk_size; + // } else { + // _seq = seq_before_padding - 1; + // } + // } + for (int s = 0; s < n_seq; ++s) { + std::vector> values(t.dimension()); + for (int i = 0; i < _dims; ++i) { + auto value = t.dataAt(0, 0, s, i); + values[i] = std::make_pair(value, i); + } + scores.push_back(values); + } + } }; class _LlmTextGenerateGreedySearchMethod : public _LlmTextGenerateMethod { @@ -100,6 +151,41 @@ class _LlmTextGenerateGreedySearchMethod : public _LlmTextGenerateMethod { unsigned int generate(Tensor &t) override; }; +class _LlmTextGenerateGreedySearchMethodForSD : public _LlmTextGenerateMethod { + public: + _LlmTextGenerateGreedySearchMethodForSD() = default; + ~_LlmTextGenerateGreedySearchMethodForSD() = default; + inline void _tensor_to_vec_of_multiIndices(Tensor &t, std::vector> &scores, std::vector indices) { + assert(t.batch() == 1 && "Batch size of result is not 1. Which is not supported for now."); + assert(t.head() == 1 && "The 3rd dim of result should be one. e.g.:[1, 1, seq, hidden]"); + int _dims = t.dimension(); + // TODO: 考虑QNN进行padding + // padding prefill for QNN + // if (is_padding) { + // if (chunk_size > 0) { + // _seq = (seq_before_padding - 1) % chunk_size; + // } else { + // _seq = seq_before_padding - 1; + // } + // } + for (int idx = 0; idx < indices.size(); ++idx) { + std::vector values(t.dimension()); + int _seq = indices[idx]; + for (int i = 0; i < _dims; ++i) { + auto value = t.dataAt(0, 0, _seq, i); + values[i] = value; + } + scores.push_back(values); + } + } + unsigned int generate(Tensor &t) override { + std::cerr << "Should use generate_SD" << std::endl; + assert(false); + return -1; + }; + unsigned int generate_SD(Tensor &t, TracePool &tp); + }; + class _LlmTextGenerateTopkSamplingMethod : public _LlmTextGenerateMethod { public: ~_LlmTextGenerateTopkSamplingMethod() = default; @@ -128,6 +214,103 @@ class _LlmTextGenerateToppSamplingMethod : public _LlmTextGenerateMethod { float m_temperature = 0.f; }; +class _LlmTextGenerateNucleusSamplingMethodForSD : public _LlmTextGenerateMethod { +public: + _LlmTextGenerateNucleusSamplingMethodForSD(int k, float p, float temp) : samplingConfig(SamplingConfig(temp, p, k)) {} + ~_LlmTextGenerateNucleusSamplingMethodForSD() = default; + + unsigned int generate(Tensor &t) override { + std::cerr << "Should use generate_SD" << std::endl; + assert(false); + return -1; + }; + unsigned int generate_SD(Tensor &t, TracePool &tp); + std::vector evalPosterior(const std::vector> &logit_scores, const std::vector &sampled_token_ids, TracePool &tp); +private: + float temperature = 1.0; + float top_p = 1.0; + int top_k = -1; + struct SamplingConfig { + float temperature = 1.0; + float top_p = 1.0; + int top_k = -1; + SamplingConfig(float _temperature, float _top_p, float _top_k): temperature(_temperature), top_p(_top_p), top_k(_top_k) {} + } samplingConfig; + + void apply_logits_processor(std::vector>& logits_with_indices, const SamplingConfig& config) { + const size_t vocab_size = logits_with_indices.size(); + if (vocab_size == 0) return; + + // 温度调整 + if (config.temperature > 0 && config.temperature != 1.0f) { + const float inv_temp = 1.0f / config.temperature; + for (auto& v : logits_with_indices) v.first *= inv_temp; + } + + // Top-k处理 + if (config.top_k > 0 && config.top_k < (int)vocab_size) { + // 
 
 class _LlmTextGenerateTopkSamplingMethod : public _LlmTextGenerateMethod {
 public:
     ~_LlmTextGenerateTopkSamplingMethod() = default;
 
@@ -128,6 +214,103 @@ class _LlmTextGenerateToppSamplingMethod : public _LlmTextGenerateMethod {
     float m_temperature = 0.f;
 };
 
+class _LlmTextGenerateNucleusSamplingMethodForSD : public _LlmTextGenerateMethod {
+public:
+    _LlmTextGenerateNucleusSamplingMethodForSD(int k, float p, float temp) :
+        samplingConfig(SamplingConfig(temp, p, k)) {}
+    ~_LlmTextGenerateNucleusSamplingMethodForSD() = default;
+
+    unsigned int generate(Tensor &t) override {
+        std::cerr << "Should use generate_SD" << std::endl;
+        assert(false);
+        return -1;
+    };
+    unsigned int generate_SD(Tensor &t, TracePool &tp);
+    std::vector<unsigned int> evalPosterior(const std::vector<std::vector<float>> &logit_scores, const std::vector<unsigned int> &sampled_token_ids, TracePool &tp);
+
+private:
+    float temperature = 1.0;
+    float top_p = 1.0;
+    int top_k = -1;
+    struct SamplingConfig {
+        float temperature = 1.0;
+        float top_p = 1.0;
+        int top_k = -1;
+        SamplingConfig(float _temperature, float _top_p, float _top_k) :
+            temperature(_temperature), top_p(_top_p), top_k(_top_k) {}
+    } samplingConfig;
+
+    void apply_logits_processor(std::vector<std::pair<float, int>> &logits_with_indices, const SamplingConfig &config) {
+        const size_t vocab_size = logits_with_indices.size();
+        if (vocab_size == 0) return;
+
+        // temperature scaling
+        if (config.temperature > 0 && config.temperature != 1.0f) {
+            const float inv_temp = 1.0f / config.temperature;
+            for (auto &v : logits_with_indices) v.first *= inv_temp;
+        }
+
+        // top-k filtering
+        if (config.top_k > 0 && config.top_k < (int)vocab_size) {
+            // partial sort to find the top-k (descending by logit)
+            std::partial_sort(
+                logits_with_indices.begin(),
+                logits_with_indices.begin() + config.top_k,
+                logits_with_indices.end(),
+                [](const std::pair<float, int> &a, const std::pair<float, int> &b) { return a.first > b.first; });
+
+            // build the mask of kept token ids and suppress everything else
+            std::vector<bool> mask(vocab_size, false);
+            for (int i = 0; i < config.top_k; ++i) {
+                mask[logits_with_indices[i].second] = true;
+            }
+            for (auto &v : logits_with_indices) {
+                if (!mask[v.second]) v.first = -INFINITY;
+            }
+        }
+
+        // top-p (nucleus) filtering
+        if (config.top_p > 0.0f && config.top_p < 1.0f) {
+            // compute softmax
+            std::vector<float> probs(vocab_size);
+            std::pair<float, int> max_logit_with_index = *std::max_element(logits_with_indices.begin(), logits_with_indices.end(),
+                [](std::pair<float, int> a, std::pair<float, int> b) { return a.first < b.first; });
+            float max_logit = max_logit_with_index.first;
+            float sum_exp = 0.0f;
+            for (size_t i = 0; i < vocab_size; ++i) {
+                probs[i] = expf(logits_with_indices[i].first - max_logit);
+                sum_exp += probs[i];
+            }
+
+            std::vector<std::pair<float, int>> sorted_probs(vocab_size);
+            for (size_t i = 0; i < vocab_size; ++i) {
+                sorted_probs[i] = std::make_pair(probs[i] / sum_exp, logits_with_indices[i].second);
+            }
+            std::sort(sorted_probs.begin(), sorted_probs.end(),
+                      [](std::pair<float, int> a, std::pair<float, int> b) { return a.first > b.first; });
+
+            // cumulative probability cutoff
+            float cumulative = 0.0f;
+            size_t cutoff = 0;
+            for (; cutoff < vocab_size; ++cutoff) {
+                cumulative += sorted_probs[cutoff].first;
+                if (cumulative > config.top_p) break;
+            }
+            cutoff = std::min(cutoff + 1, vocab_size - 1);
+
+            // build the valid set and suppress everything outside it
+            std::vector<bool> valid(vocab_size, false);
+            for (size_t i = 0; i < cutoff; ++i) {
+                valid[sorted_probs[i].second] = true;
+            }
+            for (auto &v : logits_with_indices) {
+                if (!valid[v.second]) v.first = -INFINITY;
+            }
+        }
+    }
+};
+
@@ ... @@ class LlmTextGenerator {
         return m_method_class->generate(t);
     }
+    inline unsigned int generate_SD(Tensor &t, TracePool &tp) {
+        // check that m_method_class is a _LlmTextGenerateGreedySearchMethodForSD before dispatching; otherwise report an error
+        if (m_type != LLmTextGeneratorType::kGreedySearchForSD) {
+            std::cerr << "Should use generate_SD in _LlmTextGenerateGreedySearchMethodForSD only" << std::endl;
+            assert(false);
+            return -1;
+        }
+        return dynamic_cast<_LlmTextGenerateGreedySearchMethodForSD *>(m_method_class)->generate_SD(t, tp);
+    };
+
     inline unsigned int generate(Tensor &t, const LlmTextGeneratorOpts &opt) {
         if (opt.is_padding) {
             m_method_class->setPadding(opt.is_padding, opt.seq_before_padding, opt.chunk_size);
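A quick worked example of the cumulative-probability cutoff used in the top-p branch above, extracted as a standalone helper (the function name is illustrative; it assumes probabilities already sorted in descending order):

#include <algorithm>
#include <cstddef>
#include <vector>

size_t top_p_cutoff(const std::vector<float> &sorted_probs, float top_p) {
    float cumulative = 0.0f;
    size_t cutoff = 0;
    for (; cutoff < sorted_probs.size(); ++cutoff) {
        cumulative += sorted_probs[cutoff];
        if (cumulative > top_p) break; // first index where the mass exceeds top_p
    }
    return std::min(cutoff + 1, sorted_probs.size() - 1);
}
// top_p_cutoff({0.5f, 0.3f, 0.1f, 0.1f}, 0.7f) == 2: the cumulative mass passes
// 0.7 at the second entry, so only the two most likely tokens stay valid.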
param_["rope_type"] = DEFAULT; + } else if (rope_type == "llama3") { + param_["rope_type"] = LLAMA3; + param_["factor"] = std::any_cast(rope_scaling.at("factor")); + param_["high_freq_factor"] = std::any_cast(rope_scaling.at("high_freq_factor")); + param_["low_freq_factor"] = std::any_cast(rope_scaling.at("low_freq_factor")); + param_["original_max_position_embeddings"] = std::any_cast(rope_scaling.at("original_max_position_embeddings")); + } else { + std::cout << "[TODO]rope type " << rope_type << " not support!!!!" << std::endl; + } + } + } + + init(std::move(name), OpType::ROPETREE); + } + + explicit RoPETree(int pose_type, std::string name) { + param_["pose_type"] = pose_type; + init(std::move(name), OpType::ROPETREE); + } + explicit RoPETree(int pose_type, float rope_theta, int max_position_embeddings, std::string name) { + param_["pose_type"] = pose_type; + param_["rope_theta"] = rope_theta; + param_["max_position_embeddings"] = max_position_embeddings; + init(std::move(name), OpType::ROPETREE); + } + explicit RoPETree(int pose_type, float rope_theta, float partial_rotary_factor, int max_position_embeddings, std::string name) { + param_["pose_type"] = pose_type; + param_["rope_theta"] = rope_theta; + param_["max_position_embeddings"] = max_position_embeddings; + param_["partial_rotary_factor"] = partial_rotary_factor; + init(std::move(name), OpType::ROPETREE); + } + Tensor &operator()(Tensor &input, Tensor &tree_ancestor) { + auto ts = run({input, tree_ancestor}, 1); + return ts[0].get(); + } + void clearCache() { + return op_->clearCache(); + } + }; + + class IRoPE final : public Layer { public: IRoPE() = default; diff --git a/src/backends/cpu/CPUBackend.cpp b/src/backends/cpu/CPUBackend.cpp index d5dd6610..73d00e75 100644 --- a/src/backends/cpu/CPUBackend.cpp +++ b/src/backends/cpu/CPUBackend.cpp @@ -15,10 +15,12 @@ #include "op/CPUView.hpp" #include "op/CPUAdd.hpp" #include "op/CPUCausalMask.hpp" +#include "op/CPUCausalTreeMask.hpp" #include "op/CPUSlidingWindowMask.hpp" #include "op/CPUMatmul.hpp" #include "op/CPURMSNorm.hpp" #include "op/CPURoPE.hpp" +#include "op/CPURoPETree.hpp" #include "op/CPUScale.hpp" #include "op/CPUSiLU.hpp" #include "op/CPUSoftMax.hpp" @@ -130,10 +132,12 @@ void CPUBackend::registerOps() { addCreator(PARAMETER, (CPUBackend::Creator *)(new CPUParameterCreator())); addCreator(ADD, (CPUBackend::Creator *)(new CPUAddCreator())); addCreator(CAUSALMASK, (CPUBackend::Creator *)(new CPUCausalMaskCreator())); + addCreator(CAUSALTREEMASK, (CPUBackend::Creator *)(new CPUCausalTreeMaskCreator())); addCreator(SLIDINGWINDOWMASK, (CPUBackend::Creator *)(new CPUSlidingWindowMaskCreator())); addCreator(MATMUL, (CPUBackend::Creator *)(new CPUMatmulCreator())); addCreator(RMSNORM, (CPUBackend::Creator *)(new CPURMSNormCreator())); addCreator(ROPE, (CPUBackend::Creator *)(new CPURoPECreator())); + addCreator(ROPETREE, (CPUBackend::Creator *)(new CPURoPETreeCreator())); addCreator(SCALE, (CPUBackend::Creator *)(new CPUScaleCreator())); addCreator(SILU, (CPUBackend::Creator *)(new CPUSiLUCreator())); addCreator(SOFTMAX, (CPUBackend::Creator *)(new CPUSoftMaxCreator())); diff --git a/src/backends/cpu/CPUBackend.hpp b/src/backends/cpu/CPUBackend.hpp index 3a4ea5b1..db928bae 100644 --- a/src/backends/cpu/CPUBackend.hpp +++ b/src/backends/cpu/CPUBackend.hpp @@ -64,6 +64,28 @@ class CPUBackend final : public Backend { return execution_type; } // #endif + + + // #ifdef USE_SD + void setLastDraftLength(unsigned int draft_length) { + last_draft_length = draft_length; + } + void 
diff --git a/src/backends/cpu/CPUBackend.cpp b/src/backends/cpu/CPUBackend.cpp
index d5dd6610..73d00e75 100644
--- a/src/backends/cpu/CPUBackend.cpp
+++ b/src/backends/cpu/CPUBackend.cpp
@@ -15,10 +15,12 @@
 #include "op/CPUView.hpp"
 #include "op/CPUAdd.hpp"
 #include "op/CPUCausalMask.hpp"
+#include "op/CPUCausalTreeMask.hpp"
 #include "op/CPUSlidingWindowMask.hpp"
 #include "op/CPUMatmul.hpp"
 #include "op/CPURMSNorm.hpp"
 #include "op/CPURoPE.hpp"
+#include "op/CPURoPETree.hpp"
 #include "op/CPUScale.hpp"
 #include "op/CPUSiLU.hpp"
 #include "op/CPUSoftMax.hpp"
@@ -130,10 +132,12 @@ void CPUBackend::registerOps() {
     addCreator(PARAMETER, (CPUBackend::Creator *)(new CPUParameterCreator()));
     addCreator(ADD, (CPUBackend::Creator *)(new CPUAddCreator()));
     addCreator(CAUSALMASK, (CPUBackend::Creator *)(new CPUCausalMaskCreator()));
+    addCreator(CAUSALTREEMASK, (CPUBackend::Creator *)(new CPUCausalTreeMaskCreator()));
     addCreator(SLIDINGWINDOWMASK, (CPUBackend::Creator *)(new CPUSlidingWindowMaskCreator()));
     addCreator(MATMUL, (CPUBackend::Creator *)(new CPUMatmulCreator()));
     addCreator(RMSNORM, (CPUBackend::Creator *)(new CPURMSNormCreator()));
     addCreator(ROPE, (CPUBackend::Creator *)(new CPURoPECreator()));
+    addCreator(ROPETREE, (CPUBackend::Creator *)(new CPURoPETreeCreator()));
     addCreator(SCALE, (CPUBackend::Creator *)(new CPUScaleCreator()));
     addCreator(SILU, (CPUBackend::Creator *)(new CPUSiLUCreator()));
     addCreator(SOFTMAX, (CPUBackend::Creator *)(new CPUSoftMaxCreator()));

diff --git a/src/backends/cpu/CPUBackend.hpp b/src/backends/cpu/CPUBackend.hpp
index 3a4ea5b1..db928bae 100644
--- a/src/backends/cpu/CPUBackend.hpp
+++ b/src/backends/cpu/CPUBackend.hpp
@@ -64,6 +64,28 @@ class CPUBackend final : public Backend {
         return execution_type;
     }
     // #endif
+
+    // #ifdef USE_SD
+    void setLastDraftLength(unsigned int draft_length) {
+        last_draft_length = draft_length;
+    }
+    void setLastVerifiedPositionIds(const std::vector<unsigned int> &verified_position_ids) {
+        last_verified_position_ids = verified_position_ids;
+    }
+    void setUsingDraft(bool _usingDraft) {
+        this->usingDraft = _usingDraft;
+    }
+    unsigned int getLastDraftLength() {
+        return last_draft_length;
+    }
+    std::vector<unsigned int> getLastVerifiedPositionIds() {
+        return last_verified_position_ids;
+    }
+    bool isUsingDraft() {
+        return usingDraft;
+    }
+    // #endif
 private:
     std::map<OpType, CPUBackend::Creator *> map_creator_;
     std::map<TensorFuncType, TensorFunction *> map_function_;
@@ -77,6 +99,13 @@ class CPUBackend final : public Backend {
     bool isSwitchingStage = false;
     ExecutionType execution_type = PROMPT;
     // #endif
+
+    // #ifdef USE_SD
+    bool usingDraft = false;
+    std::vector<unsigned int> last_verified_position_ids;
+    unsigned int last_draft_length = 0;
+    // #endif
 };
 
 } // namespace mllm

diff --git a/src/backends/cpu/op/CPUCausalTreeMask.cpp b/src/backends/cpu/op/CPUCausalTreeMask.cpp
new file mode 100644
index 00000000..e5594c8e
--- /dev/null
+++ b/src/backends/cpu/op/CPUCausalTreeMask.cpp
@@ -0,0 +1,112 @@
+
+#include "CPUCausalTreeMask.hpp"
+#include <cmath>
+
+namespace mllm {
+
+CPUCausalTreeMask::CPUCausalTreeMask(Backend *bn, string opName, int threadCount) :
+    thread_count(threadCount),
+    Op(bn, opName) {
+}
+
+ErrorCode CPUCausalTreeMask::reshape(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) {
+    // std::cout << "CPUMask reshape" << std::endl;
+    // assert(inputs.size() == 1);
+    assert(outputs.size() == 1);
+    outputs[0]->reshape(inputs[0]->batch(), inputs[0]->head(), inputs[0]->sequence(), inputs[0]->dimension());
+    return Op::reshape(inputs, outputs);
+}
+
+/*
+ * inputs: hidden_values, tree_ancestor
+ * e.g. hidden_values [6 x 6], tree_ancestor [-1, 0, 1, 1, 3], draft_size = 5
+ * Target Mask (usual):
+ * TFFFFF
+ * TTFFFF
+ * TTTFFF
+ * TTTTFF
+ * TTTTTF
+ * TTTTTT
+ *
+ * Target Mask (tree):
+ * cR0113
+ * TFFFFF
+ * TTFFFF
+ * TTTFFF
+ * TTTTFF
+ * TTTFTF
+ * TTTFTT
+ * Row s (for s >= c) gets the same mask as the row of its tree ancestor,
+ * i.e. row tree_ancestor[s-c]+c, plus a T at position [s, s].
+ */
+ErrorCode CPUCausalTreeMask::execute(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) {
+    if (inputs[0]->sequence() > 1) {
+        int batch_size = inputs[0]->batch();
+        int head_num = inputs[0]->head();
+        int sequence = inputs[0]->sequence();
+        int dimension = inputs[0]->dimension();
+        int old_dim = 0;
+
+        auto &tree_ancestor = inputs[2];
+        int draft_size = tree_ancestor->sequence() - 1;
+        int context_size = sequence - draft_size;
+        if (inputs.size() > 1) {
+            old_dim = (int)inputs[1]->dataAt<float>(0, 0, 0, 0) - sequence;
+        } else {
+#ifndef LLAMAFILE_SGEMM
+            old_dim = dimension - sequence;
+#endif
+        }
+#pragma omp parallel for collapse(3) num_threads(thread_count)
+        for (int n = 0; n < batch_size; ++n) {
+            for (int h = 0; h < head_num; ++h) {
+                for (int s = 0; s < sequence; ++s) {
+                    // non-draft part (the previous context): ordinary causal mask
+                    if (s < context_size) {
+                        for (int d = 0; d < inputs[0]->dimension(); ++d) {
+                            if (d > s + old_dim) {
+                                outputs[0]->setDataAt<float>({n, h, s, d}, -INFINITY);
+                            } else {
+                                outputs[0]->setDataAt<float>({n, h, s, d}, inputs[0]->dataAt<float>(n, h, s, d));
+                            }
+                        }
+                    }
+                    // draft part (the tree): the mask follows the tree structure
+                    else {
+                        int tree_ancestor_s = tree_ancestor->dataAt<int>({0, 0, s - context_size + 1, 0}); // + context_size;
+                        for (int d = 0; d < inputs[0]->dimension(); ++d) {
+                            if (d == s) {
+                                outputs[0]->setDataAt<float>({n, h, s, d}, inputs[0]->dataAt<float>(n, h, s, d));
+                            } else {
+                                if (outputs[0]->dataAt<float>({n, h, tree_ancestor_s, d}) == -INFINITY) {
+                                    outputs[0]->setDataAt<float>({n, h, s, d}, -INFINITY);
+                                } else {
+                                    outputs[0]->setDataAt<float>({n, h, s, d}, inputs[0]->dataAt<float>(n, h, s, d));
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    } else {
+        outputs[0]->copyFrom(inputs[0]);
+    }
+    return Op::execute(inputs, outputs);
+}
+
+ErrorCode CPUCausalTreeMask::setUp(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) {
+    // assert(inputs.size() == 1);
+    assert(outputs.size() == 1);
+    if (inputs[0]->masterTensor() == nullptr) {
+        inputs[0]->free(); // TODO remove
+    }
+    outputs[0]->setDtype(activation_dtype());
+    outputs[0]->alloc();
+    inputs[0]->shallowCopyFrom(outputs[0].get(), false);
+    return MLLM_NO_ERROR;
+}
+} // namespace mllm

diff --git a/src/backends/cpu/op/CPUCausalTreeMask.hpp b/src/backends/cpu/op/CPUCausalTreeMask.hpp
new file mode 100644
index 00000000..574a7483
--- /dev/null
+++ b/src/backends/cpu/op/CPUCausalTreeMask.hpp
@@ -0,0 +1,29 @@
+#ifndef MLLM_CPUCAUSALTREEMASK_H
+#define MLLM_CPUCAUSALTREEMASK_H
+
+#include "Op.hpp"
+#include "../CPUBackend.hpp"
+
+namespace mllm {
+
+class CPUCausalTreeMask final : public Op {
+public:
+    CPUCausalTreeMask(Backend *bn, string opName, int threadCount);
+    virtual ~CPUCausalTreeMask() = default;
+    virtual ErrorCode reshape(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) override;
+    virtual ErrorCode execute(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) override;
+    virtual ErrorCode setUp(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) override;
+
+private:
+    int thread_count = 4;
+};
+
+class CPUCausalTreeMaskCreator : public CPUBackend::Creator {
+public:
+    virtual Op *create(OpParam op_param, Backend *bn, string name, int threadCount) const {
+        return new CPUCausalTreeMask(bn, name, threadCount);
+    }
+};
+} // namespace mllm
+
+#endif // MLLM_CPUCAUSALTREEMASK_H
\ No newline at end of file
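The execute() above applies the inheritance rule row by row on the score tensor. The same rule in isolation, for a draft-only block (a minimal sketch, not part of the diff; `ancestors[j]` is the draft-local parent of node j, or -1 for a node hanging directly off the start token):

#include <vector>

std::vector<std::vector<bool>> tree_mask(const std::vector<int> &ancestors) {
    size_t n = ancestors.size();
    std::vector<std::vector<bool>> m(n, std::vector<bool>(n, false));
    for (size_t j = 0; j < n; ++j) {
        if (ancestors[j] >= 0) m[j] = m[ancestors[j]]; // inherit the ancestor's visibility row
        m[j][j] = true;                                // ...and make the node see itself
    }
    return m;
}

Each node therefore attends exactly to its ancestor chain, which is what the -INFINITY copy in execute() implements on top of the context columns.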
diff --git a/src/backends/cpu/op/CPUKVCache.cpp b/src/backends/cpu/op/CPUKVCache.cpp
index 638bc453..830d3a78 100644
--- a/src/backends/cpu/op/CPUKVCache.cpp
+++ b/src/backends/cpu/op/CPUKVCache.cpp
@@ -67,6 +67,15 @@ ErrorCode CPUKVCache::reshape(vector<shared_ptr<Tensor>> inputs,
         };
         cache_seq_len_ = 0;
     }
+
+    // for sd
+    auto cpuBackend = dynamic_cast<CPUBackend *>(backend_);
+    if (cpuBackend->isUsingDraft()) {
+        unsigned int last_draft_length = cpuBackend->getLastDraftLength();
+        const std::vector<unsigned int> &last_verified_position_ids = cpuBackend->getLastVerifiedPositionIds();
+        cache_seq_len_ = cache_seq_len_ - (last_draft_length) + last_verified_position_ids.size();
+    }
+
     int sequence = inputs[0]->sequence() + cache_seq_len_;
 #ifdef LLAMAFILE_SGEMM
     if (!for_xnn_ && sequence % n_pack != 0) sequence = ((sequence + (n_pack - 1)) / n_pack) * n_pack;
@@ -91,6 +100,15 @@ ErrorCode CPUKVCache::load(AbstructLoader &loader) {
 
 ErrorCode CPUKVCache::execute(vector<shared_ptr<Tensor>> inputs,
                               vector<shared_ptr<Tensor>> outputs) {
+    // for sd
+    auto cpuBackend = dynamic_cast<CPUBackend *>(backend_);
+    if (cpuBackend->isUsingDraft()) {
+        const std::vector<unsigned int> &last_verified_position_ids = cpuBackend->getLastVerifiedPositionIds();
+        if (!last_verified_position_ids.empty()) {
+            this->updateVerifiedKVCache(last_verified_position_ids);
+        }
+    }
+
     int cache_seq_len_old = cache_seq_len_;
     cache_seq_len_ += inputs[0]->sequence();
     if (n_rep_ > 1) {
@@ -193,4 +211,125 @@ ErrorCode CPUKVCache::setUp(vector<shared_ptr<Tensor>> inputs,
+
+ErrorCode CPUKVCache::updateVerifiedKVCache(const std::vector<unsigned int> &verified_position_ids) {
+    if (cache_.ctype() == BSHD) {
+        unsigned int dest_pid = cache_seq_len_ - verified_position_ids.size();
+        for (unsigned int src_pid : verified_position_ids) {
+            if (src_pid == dest_pid) {
+                dest_pid += 1;
+                continue;
+            }
+            // #pragma omp parallel for collapse(1) num_threads(thread_count)
+            for (int b = 0; b < cache_.batch(); ++b) {
+                if (cache_.dtype() == MLLM_TYPE_F32) {
+                    auto src_ptr = cache_.ptrAt<float>(b, 0, src_pid, 0);
+                    auto dest_ptr = cache_.ptrAt<float>(b, 0, dest_pid, 0);
+                    int copy_size = cache_.dimension() * cache_.head();
+                    memcpy(dest_ptr, src_ptr, copy_size * sizeof(float));
+                } else if (cache_.dtype() == MLLM_TYPE_F16) {
+                    auto src_ptr = cache_.ptrAt<mllm_fp16_t>(b, 0, src_pid, 0);
+                    auto dest_ptr = cache_.ptrAt<mllm_fp16_t>(b, 0, dest_pid, 0);
+                    int copy_size = cache_.dimension() * cache_.head();
+                    memcpy(dest_ptr, src_ptr, copy_size * sizeof(mllm_fp16_t));
+                } else if (cache_.dtype() == MLLM_TYPE_Q8_0) {
+                    // TODO: Q8 check
+                    auto src_ptr =
+                        (char *)cache_.rawHostPtr() + cache_.offset(b, 0, src_pid, 0) * sizeof(block_q8_0) / QK8_0;
+                    auto dest_ptr = (char *)cache_.rawHostPtr() + cache_.offset(b, 0, dest_pid, 0) * sizeof(block_q8_0) / QK8_0;
+                    int copy_size = cache_.dimension() * cache_.head();
+                    memcpy(dest_ptr, src_ptr, copy_size * sizeof(block_q8_0) / QK8_0);
+                }
+            }
+            dest_pid += 1;
+        }
+    } else if (cache_.ctype() == BHDS) {
+        unsigned int dest_pid = cache_seq_len_ - verified_position_ids.size();
+        for (unsigned int src_pid : verified_position_ids) {
+            if (src_pid == dest_pid) {
+                dest_pid += 1;
+                continue;
+            }
+#pragma omp parallel for collapse(3) num_threads(thread_count)
+            for (int b = 0; b < cache_.batch(); ++b) {
+                for (int h = 0; h < cache_.head(); ++h) {
+                    for (int d = 0; d < cache_.dimension(); ++d) {
+                        if (cache_.dtype() == MLLM_TYPE_F32) {
+                            auto src_data = cache_.dataAt<float>(b, h, src_pid, d);
+                            cache_.setDataAt<float>(b, h, dest_pid, d, src_data);
+                        } else if (cache_.dtype() == MLLM_TYPE_F16) {
+                            auto src_data = cache_.dataAt<mllm_fp16_t>(b, h, src_pid, d);
+                            cache_.setDataAt<mllm_fp16_t>(b, h, dest_pid, d, src_data);
+                        } else if (cache_.dtype() == MLLM_TYPE_Q8_0) {
+                            // TODO: Q8 check — unclear whether Q8 data can be moved with setDataAt directly
+                            // auto src_data = cache_.dataAt(b, h, src_pid, d);
+                            // cache_.setDataAt(b, h, dest_pid, d, src_data);
+                            auto src_ptr =
+                                (char *)cache_.rawHostPtr() + cache_.offset(b, h, src_pid, d) * sizeof(block_q8_0) / QK8_0;
+                            auto dest_ptr = (char *)cache_.rawHostPtr() + cache_.offset(b, h, dest_pid, d) * sizeof(block_q8_0) / QK8_0;
+                            int copy_size = 1;
+                            memcpy(dest_ptr, src_ptr, copy_size * sizeof(block_q8_0) / QK8_0);
+                        }
+                    }
+                }
+            }
+            dest_pid += 1;
+        }
+    } else {
+        std::cout << "ERROR Ctype in KVCache;" << std::endl;
+    }
+
+    // clear kv cache
+    // if (cache_seq_len_ < cache_seq_len_old) {
+    //     if (n_rep_ > 1) {
+    //         if (cache_.ctype() == BSHD) {
+    //             for (int b = 0; b < cache_.batch(); ++b) {
+    //                 for (int h = cache_.head() - 1; h >= 0; --h) {
+    //                     // #pragma omp parallel for collapse(2) num_threads(thread_count)
+    //                     for (int seq = cache_seq_len_; seq < cache_seq_len_old; ++seq) {
+    //                         for (int i_rep = 0; i_rep < n_rep_; ++i_rep) {
+    //                             auto cache_head = h * n_rep_ + i_rep;
+    //                             if (cache_.dtype() == MLLM_TYPE_F32) {
+    //                                 auto dest_ptr = cache_.ptrAt<float>(b, cache_head, seq, 0);
+    //                                 int copy_size = cache_.dimension();
+    //                                 memset(dest_ptr, 0, copy_size * sizeof(float));
+    //                             } else if (cache_.dtype() == MLLM_TYPE_F16) {
+    //                                 auto dest_ptr = cache_.ptrAt<mllm_fp16_t>(b, cache_head, seq, 0);
+    //                                 int copy_size = cache_.dimension();
+    //                                 memset(dest_ptr, 0, copy_size * sizeof(mllm_fp16_t));
+    //                             }
+    //                         }
+    //                     }
+    //                 }
+    //             }
+    //         } else if (cache_.ctype() == BHDS) {
+    //             for (int b = 0; b < cache_.batch(); ++b) {
+    //                 for (int h = cache_.head() - 1; h >= 0; --h) {
+    //                     // #pragma omp parallel for collapse(2) num_threads(thread_count)
+    //                     for (int d = 0; d < cache_.dimension(); ++d) {
+    //                         for (int i_rep = 0; i_rep < n_rep_; ++i_rep) {
+    //                             auto cache_head = h * n_rep_ + i_rep;
+    //                             if (cache_.dtype() == MLLM_TYPE_F32) {
+    //                                 auto dest_ptr =
+    //                                     cache_.ptrAt<float>(b, cache_head, cache_seq_len_, d);
+    //                                 int copy_size = cache_seq_len_old - cache_seq_len_;
+    //                                 memset(dest_ptr, 0, copy_size * sizeof(float));
+    //                             } else if (cache_.dtype() == MLLM_TYPE_F16) {
+    //                                 auto dest_ptr =
+    //                                     cache_.ptrAt<mllm_fp16_t>(b, cache_head, cache_seq_len_, d);
+    //                                 int copy_size = cache_seq_len_old - cache_seq_len_;
+    //                                 memset(dest_ptr, 0, copy_size * sizeof(mllm_fp16_t));
+    //                             }
+    //                         }
+    //                     }
+    //                 }
+    //             }
+    //         } else {
+    //             std::cout << "ERROR Ctype in KVCache;" << std::endl;
+    //         }
+    //     }
+    // }
+    return MLLM_NO_ERROR;
+}
+
 } // namespace mllm
\ No newline at end of file
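The compaction above, reduced to a single sequence axis (a toy sketch, not part of the diff; `verified` holds the absolute cache positions of the accepted tokens and `new_len` is the cache length after the rollback done in reshape()):

#include <vector>

void compact(std::vector<float> &cache, unsigned int new_len,
             const std::vector<unsigned int> &verified) {
    unsigned int dest = new_len - verified.size();
    for (unsigned int src : verified) {
        if (src != dest) cache[dest] = cache[src]; // same move memcpy/setDataAt performs per row
        ++dest;
    }
}

Rejected draft rows simply get overwritten by the next forward pass, which is why no explicit clearing is needed (the commented-out "clear kv cache" block above is the abandoned alternative).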
diff --git a/src/backends/cpu/op/CPUKVCache.hpp b/src/backends/cpu/op/CPUKVCache.hpp
index 12392b30..7f78cdf1 100644
--- a/src/backends/cpu/op/CPUKVCache.hpp
+++ b/src/backends/cpu/op/CPUKVCache.hpp
@@ -31,6 +31,8 @@ class CPUKVCache final : public Op {
         for_xnn_ = for_xnn;
     }
 
+    ErrorCode updateVerifiedKVCache(const std::vector<unsigned int> &verified_position_ids);
+
 private:
     int thread_count = 4;
 
diff --git a/src/backends/cpu/op/CPURoPE.cpp b/src/backends/cpu/op/CPURoPE.cpp
index ff9d1a6e..0548771c 100644
--- a/src/backends/cpu/op/CPURoPE.cpp
+++ b/src/backends/cpu/op/CPURoPE.cpp
@@ -16,7 +16,6 @@ vector<vector<float>> CPURoPE::cos_;
 int CPURoPE::global_pose_type_ = -1;
 int CPURoPE::ishape_old;
 
-typedef float (*mllm_rope_init_func)(const OpParam &, std::vector<float> &);
 float _default_init_rope(const OpParam &config, vector<float> &theta) {
     auto base = config.at("base");
     // theta_i = base^-(2i/dim) = 1 / base^(2i/dim)   i from 0 to (dim/2 - 1)
@@ -70,10 +69,6 @@ float _compute_llama3_theta(const OpParam &config, vector<float> &theta) {
     return 1.0;
 }
 
-static const unordered_map<RoPEThetaType, mllm_rope_init_func> rope_init_func_map = {
-    {DEFAULT, _default_init_rope},
-    {LLAMA3, _compute_llama3_theta},
-};
 
 void sinusoidal_position_embedding_llama(int seq_len, int output_dim, const vector<float> &theta,
                                          vector<vector<float>> &sin, vector<vector<float>> &cos, float attention_scaling = 1.0) {

diff --git a/src/backends/cpu/op/CPURoPE.hpp b/src/backends/cpu/op/CPURoPE.hpp
index bce12eac..74138fb4 100644
--- a/src/backends/cpu/op/CPURoPE.hpp
+++ b/src/backends/cpu/op/CPURoPE.hpp
@@ -6,6 +6,21 @@
 
 namespace mllm {
 
+typedef float (*mllm_rope_init_func)(const OpParam &, std::vector<float> &);
+
+float _default_init_rope(const OpParam &config, vector<float> &theta);
+float _compute_llama3_theta(const OpParam &config, vector<float> &theta);
+
+static const unordered_map<RoPEThetaType, mllm_rope_init_func> rope_init_func_map = {
+    {DEFAULT, _default_init_rope},
+    {LLAMA3, _compute_llama3_theta},
+};
+
+void sinusoidal_position_embedding_llama(int seq_len, int output_dim, const vector<float> &theta,
+                                         vector<vector<float>> &sin, vector<vector<float>> &cos, float attention_scaling);
+void sinusoidal_position_embedding_huggingface(int seq_len, int output_dim, const vector<float> &theta,
+                                               vector<vector<float>> &sin, vector<vector<float>> &cos, float attention_scaling);
+
 class CPURoPE final : public Op {
 public:
     CPURoPE(Backend *bn, string opName, int pose_type, int threadCount);
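These helpers are hoisted into the header so CPURoPETree can reuse them. What `_default_init_rope` computes, per its own comment (theta_i = base^-(2i/dim)), as a standalone sketch:

#include <cmath>
#include <vector>

std::vector<float> default_theta(float base, int dim) {
    std::vector<float> theta(dim / 2);
    for (int i = 0; i < dim / 2; ++i) {
        theta[i] = std::pow(base, -2.0f * i / dim); // theta_i = 1 / base^(2i/dim)
    }
    return theta;
}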
diff --git a/src/backends/cpu/op/CPURoPETree.cpp b/src/backends/cpu/op/CPURoPETree.cpp
new file mode 100644
index 00000000..856f410e
--- /dev/null
+++ b/src/backends/cpu/op/CPURoPETree.cpp
@@ -0,0 +1,458 @@
+
+#include "CPURoPE.hpp"
+#include "CPURoPETree.hpp"
+#include "Timing.hpp"
+#include "Types.hpp"
+#include <cassert>
+#include <cmath>
+#include <memory>
+#include "backends/cpu/quantize/QuantizeQ8.hpp"
+
+namespace mllm {
+
+vector<float> CPURoPETree::theta_;
+
+vector<vector<float>> CPURoPETree::sin_;
+vector<vector<float>> CPURoPETree::cos_;
+int CPURoPETree::global_pose_type_ = -1;
+int CPURoPETree::ishape_old;
+
+CPURoPETree::CPURoPETree(Backend *bn, string opName, int pose_type, int threadCount) :
+    thread_count(threadCount),
+    Op(bn, opName) {
+    pose_type_ = pose_type;
+}
+
+CPURoPETree::CPURoPETree(Backend *bn, string opName, int pose_type, float rope_theta, int max_position_embeddings, int threadCount) :
+    thread_count(threadCount),
+    Op(bn, opName) {
+    pose_type_ = pose_type;
+    rope_theta_ = rope_theta;
+    pos_max_ = max_position_embeddings;
+}
+
+CPURoPETree::CPURoPETree(Backend *bn, string opName, int pose_type, float rope_theta, float partial_rotary_factor, int max_position_embeddings, int threadCount) :
+    thread_count(threadCount),
+    Op(bn, opName) {
+    pose_type_ = pose_type;
+    rope_theta_ = rope_theta;
+    partial_rotary_factor_ = partial_rotary_factor;
+    pos_max_ = max_position_embeddings;
+}
+
+CPURoPETree::CPURoPETree(Backend *bn, string opName, OpParam &config, int threadCount) :
+    thread_count(threadCount),
+    Op(bn, opName) {
+    config_ = config;
+    pose_type_ = config.at("pose_type");
+    auto it = config.find("rope_theta");
+    if (it != config.end()) {
+        rope_theta_ = it->second;
+    }
+    it = config.find("partial_rotary_factor");
+    if (it != config.end()) {
+        partial_rotary_factor_ = it->second;
+    }
+    it = config.find("max_position_embeddings");
+    if (it != config.end()) {
+        pos_max_ = it->second;
+    }
+    rope_type = (RoPEThetaType)config.at("rope_type");
+}
+
+ErrorCode CPURoPETree::reshape(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) {
+    // std::cout << name() << " CPURoPETree reshape" << std::endl;
+    assert(inputs.size() == 2);
+    assert(outputs.size() == 1);
+    outputs[0]->reshape(inputs[0]->batch(), inputs[0]->head(), inputs[0]->sequence(), inputs[0]->dimension());
+    ishape = inputs[0]->dimension() * partial_rotary_factor_;
+    // pos_max_ = 16384;
+
+    if (sin_.empty() || ishape_old < ishape || global_pose_type_ != pose_type_) {
+        auto calc_theta = rope_init_func_map.at(rope_type);
+        auto config = config_;
+        config["base"] = (float)rope_theta_;
+        config["dim"] = ishape;
+        float attention_scaling = calc_theta(config, theta_);
+
+        global_pose_type_ = pose_type_;
+        ishape_old = ishape;
+        if (pose_type_ == LLAMAROPE) {
+            sinusoidal_position_embedding_llama(pos_max_, ishape, theta_, sin_, cos_, attention_scaling);
+        } else if (pose_type_ == PERSIMMONROPE) {
+            sinusoidal_position_embedding_huggingface(pos_max_, ishape / 2, theta_, sin_, cos_, attention_scaling);
+        } else if (pose_type_ == HFHUBROPE || pose_type_ == MLAROPE) {
+            sinusoidal_position_embedding_huggingface(pos_max_, ishape, theta_, sin_, cos_, attention_scaling);
+        } else {
+        }
+    }
+#ifdef USE_QNN
+    auto cpuBackend = dynamic_cast<CPUBackend *>(backend_);
+    if (cpuBackend->isStageSwitching()) {
+        h_cnt_ = cpuBackend->getCurSequenceLength();
+    }
+#else
+    auto cpuBackend = dynamic_cast<CPUBackend *>(backend_);
+#endif
+
+    // for sd
+    if (cpuBackend->isUsingDraft()) {
+        unsigned int last_draft_length = cpuBackend->getLastDraftLength();
+        const std::vector<unsigned int> &last_verified_position_ids = cpuBackend->getLastVerifiedPositionIds();
+        h_cnt_ = h_cnt_ - (last_draft_length) + last_verified_position_ids.size();
+        if (h_cnt_ < 0) {
+            h_cnt_ = 0;
+        }
+    }
+    return Op::reshape(inputs, outputs);
+}
+
+void CPURoPETree::rope_llama(shared_ptr<Tensor> input, shared_ptr<Tensor> output) {
+    auto out_dtype = output->dtype();
+    int partial_dimension = (input->dimension()) * partial_rotary_factor_;
+#pragma omp parallel for collapse(4) num_threads(thread_count)
+    for (int n = 0; n < input->batch(); ++n) {
+        for (int h = 0; h < input->head(); ++h) {
+            for (int s = 0; s < input->sequence(); ++s) { // sequence
+                for (int d = 0; d < partial_dimension; d += 2) {
+                    float in_value = input->dataAt<float>(n, h, s, d);
+                    float in_value_2 = input->dataAt<float>(n, h, s, d + 1);
+                    float sin_value = sin_[s + h_cnt_][d];
+                    float cos_value = cos_[s + h_cnt_][d];
+                    auto value = in_value * cos_value - in_value_2 * sin_value;
+                    auto value2 = in_value * sin_value + in_value_2 * cos_value;
+                    if (out_dtype == MLLM_TYPE_F32) {
+                        output->setDataAt<float>(n, h, s, d, value);
+                        output->setDataAt<float>(n, h, s, d + 1, value2);
+                    } else if (out_dtype == MLLM_TYPE_F16) {
+                        output->setDataAt<mllm_fp16_t>(n, h, s, d, MLLM_FP32_TO_FP16(value));
+                        output->setDataAt<mllm_fp16_t>(n, h, s, d + 1, MLLM_FP32_TO_FP16(value2));
+                    }
+                }
+            }
+        }
+    }
+}
+
+void CPURoPETree::rope_hf(shared_ptr<Tensor> input, shared_ptr<Tensor> output, shared_ptr<Tensor> tree_ancestor) {
+    auto out_dtype = output->dtype();
+    int partial_dimension = (input->dimension()) * partial_rotary_factor_;
+    int half = (int)(partial_dimension / 2);
+    assert(partial_dimension % 2 == 0);
+    std::vector<int> position_ids(tree_ancestor->sequence());
+    if (output->ctype() == BSHD) {
+        if (input->dtype() == MLLM_TYPE_F16) {
+#pragma omp parallel for collapse(3) num_threads(thread_count)
+            for (int n = 0; n < input->batch(); ++n) {
+                for (int h = 0; h < input->head(); ++h) {
+                    for (int s = 0; s < input->sequence(); ++s) { // sequence
+                        int pos = 0;
+                        if (s == 0 || tree_ancestor->sequence() == 1) {
+                            pos = s + h_cnt_;
+                            if (tree_ancestor->sequence() > 1) {
+                                position_ids[s] = pos;
+                            }
+                        } else {
+                            auto ancestor_idx = tree_ancestor->dataAt<int>(0, 0, s, 0);
+                            pos = position_ids[ancestor_idx] + 1;
+                            position_ids[s] = pos;
+                        }
+                        for (int d = 0; d < partial_dimension / 2; ++d) {
+                            auto v = input->ptrAt<mllm_fp16_t>(n, h, s, d);
+                            auto o = output->ptrAt<mllm_fp16_t>(n, h, s, d);
+                            float in_value = static_cast<float>(v[0]);
+                            float in_value_2 = static_cast<float>(v[half]);
+                            float sin_value = sin_[pos][d];
+                            float cos_value = cos_[pos][d];
+                            auto value = in_value * cos_value - in_value_2 * sin_value;
+                            auto value2 = in_value * sin_value + in_value_2 * cos_value;
+                            o[0] = MLLM_FP32_TO_FP16(value);
+                            o[half] = MLLM_FP32_TO_FP16(value2);
+                        }
+                    }
+                }
+            }
+
+        } else {
+            if (out_dtype == MLLM_TYPE_F32) {
+#pragma omp parallel for collapse(3) num_threads(thread_count)
+                for (int n = 0; n < input->batch(); ++n) {
+                    for (int h = 0; h < input->head(); ++h) {
+                        for (int s = 0; s < input->sequence(); ++s) { // sequence
+                            int pos = 0;
+                            if (s == 0 || tree_ancestor->sequence() == 1) {
+                                pos = s + h_cnt_;
+                                if (tree_ancestor->sequence() > 1) {
+                                    position_ids[s] = pos;
+                                }
+                            } else {
+                                auto ancestor_idx = tree_ancestor->dataAt<int>(0, 0, s, 0);
+                                pos = position_ids[ancestor_idx] + 1;
+                                position_ids[s] = pos;
+                            }
+                            for (int d = 0; d < partial_dimension / 2; ++d) {
+                                auto v = input->ptrAt<float>(n, h, s, d);
+                                auto o = output->ptrAt<float>(n, h, s, d);
+                                float in_value = v[0];
+                                float in_value_2 = v[half];
+                                float sin_value = sin_[pos][d];
+                                float cos_value = cos_[pos][d];
+                                auto value = in_value * cos_value - in_value_2 * sin_value;
+                                auto value2 = in_value * sin_value + in_value_2 * cos_value;
+                                o[0] = value;
+                                o[half] = value2;
+                            }
+                        }
+                    }
+                }
+            } else if (out_dtype == MLLM_TYPE_F16) {
+#pragma omp parallel for collapse(3) num_threads(thread_count)
+                for (int n = 0; n < input->batch(); ++n) {
+                    for (int h = 0; h < input->head(); ++h) {
+                        for (int s = 0; s < input->sequence(); ++s) { // sequence
+                            int pos = 0;
+                            if (s == 0 || tree_ancestor->sequence() == 1) {
+                                pos = s + h_cnt_;
+                                if (tree_ancestor->sequence() > 1) {
+                                    position_ids[s] = pos;
+                                }
+                            } else {
+                                auto ancestor_idx = tree_ancestor->dataAt<int>(0, 0, s, 0);
+                                pos = position_ids[ancestor_idx] + 1;
+                                position_ids[s] = pos;
+                            }
+                            for (int d = 0; d < partial_dimension / 2; ++d) {
+                                auto v = input->ptrAt<float>(n, h, s, d);
+                                auto o = output->ptrAt<mllm_fp16_t>(n, h, s, d);
+                                float in_value = v[0];
+                                float in_value_2 = v[half];
+                                float sin_value = sin_[pos][d];
+                                float cos_value = cos_[pos][d];
+                                auto value = in_value * cos_value - in_value_2 * sin_value;
+                                auto value2 = in_value * sin_value + in_value_2 * cos_value;
+                                o[0] = MLLM_FP32_TO_FP16(value);
+                                o[half] = MLLM_FP32_TO_FP16(value2);
+                            }
+                        }
+                    }
+                }
+            }
+        }
+        return;
+    }
+#pragma omp parallel for collapse(3) num_threads(thread_count)
+    for (int n = 0; n < input->batch(); ++n) {
+        for (int h = 0; h < input->head(); ++h) {
+            for (int s = 0; s < input->sequence(); ++s) { // sequence
+                int pos = 0;
+                if (s == 0 || tree_ancestor->sequence() == 1) {
+                    pos = s + h_cnt_;
+                    if (tree_ancestor->sequence() > 1) {
+                        position_ids[s] = pos;
+                    }
+                } else {
+                    auto ancestor_idx = tree_ancestor->dataAt<int>(0, 0, s, 0);
+                    pos = position_ids[ancestor_idx] + 1;
+                    position_ids[s] = pos;
+                }
+                for (int d = 0; d < partial_dimension / 2; ++d) {
+                    if (input->dtype() == MLLM_TYPE_F16) {
+                        float in_value = static_cast<float>(input->dataAt<mllm_fp16_t>(n, h, s, d));
+                        float in_value_2 = static_cast<float>(input->dataAt<mllm_fp16_t>(n, h, s, d + partial_dimension / 2));
+                        float sin_value = sin_[pos][d];
+                        float cos_value = cos_[pos][d];
+                        auto value = in_value * cos_value - in_value_2 * sin_value;
+                        auto value2 = in_value * sin_value + in_value_2 * cos_value;
+                        if (out_dtype == MLLM_TYPE_F32) {
+                            output->setDataAt<float>(n, h, s, d, value);
+                            output->setDataAt<float>(n, h, s, d + partial_dimension / 2, value2);
+                        } else if (out_dtype == MLLM_TYPE_F16) {
+                            output->setDataAt<mllm_fp16_t>(n, h, s, d, MLLM_FP32_TO_FP16(value));
+                            output->setDataAt<mllm_fp16_t>(n, h, s, d + partial_dimension / 2, MLLM_FP32_TO_FP16(value2));
+                        }
+
+                    } else {
+                        float in_value = input->dataAt<float>(n, h, s, d);
+                        float in_value_2 = input->dataAt<float>(n, h, s, d + partial_dimension / 2);
+                        float sin_value = sin_[pos][d];
+                        float cos_value = cos_[pos][d];
+                        auto value = in_value * cos_value - in_value_2 * sin_value;
+                        auto value2 = in_value * sin_value + in_value_2 * cos_value;
+                        if (out_dtype == MLLM_TYPE_F32) {
+                            output->setDataAt<float>(n, h, s, d, value);
+                            output->setDataAt<float>(n, h, s, d + partial_dimension / 2, value2);
+                        } else if (out_dtype == MLLM_TYPE_F16) {
+                            output->setDataAt<mllm_fp16_t>(n, h, s, d, MLLM_FP32_TO_FP16(value));
+                            output->setDataAt<mllm_fp16_t>(n, h, s, d + partial_dimension / 2, MLLM_FP32_TO_FP16(value2));
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+void CPURoPETree::rope_permission(shared_ptr<Tensor> input, shared_ptr<Tensor> output) {
+    auto out_dtype = output->dtype();
+    int partial_dimension = (input->dimension()) * partial_rotary_factor_;
+#pragma omp parallel for collapse(4) num_threads(thread_count)
+    for (int n = 0; n < input->batch(); ++n) {
+        for (int h = 0; h < input->head(); ++h) {
+            for (int s = 0; s < input->sequence(); ++s) { // sequence
+                for (int d = 0; d < partial_dimension; ++d) {
+                    float in_value = input->dataAt<float>(n, h, s, d);
+                    float in_value_2;
+                    float sin_value = sin_[s + h_cnt_][d];
+                    float cos_value = cos_[s + h_cnt_][d];
+                    if (d < partial_dimension / 4) {
+                        in_value_2 = -input->dataAt<float>(n, h, s, d + partial_dimension / 4);
+                        auto value = in_value * cos_value + in_value_2 * sin_value;
+                        if (out_dtype == MLLM_TYPE_F32) {
+                            output->setDataAt<float>(n, h, s, d, value);
+                        } else if (out_dtype == MLLM_TYPE_F16) {
+                            output->setDataAt<mllm_fp16_t>(n, h, s, d, MLLM_FP32_TO_FP16(value));
+                        }
+                    } else if (d < (partial_dimension / 2)) {
+                        in_value_2 = input->dataAt<float>(n, h, s, d - partial_dimension / 4);
+                        auto value = in_value * cos_value + in_value_2 * sin_value;
+                        if (out_dtype == MLLM_TYPE_F32) {
+                            output->setDataAt<float>(n, h, s, d, value);
+                        } else if (out_dtype == MLLM_TYPE_F16) {
+                            output->setDataAt<mllm_fp16_t>(n, h, s, d, MLLM_FP32_TO_FP16(value));
+                        }
+                    } else {
+                        if (out_dtype == MLLM_TYPE_F32) {
+                            output->setDataAt<float>(n, h, s, d, in_value);
+                        } else if (out_dtype == MLLM_TYPE_F16) {
+                            output->setDataAt<mllm_fp16_t>(n, h, s, d, MLLM_FP32_TO_FP16(in_value));
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+void CPURoPETree::rope_mla(shared_ptr<Tensor> input, shared_ptr<Tensor> output) {
+    auto out_dtype = output->dtype();
+    int partial_dimension = (input->dimension()) * partial_rotary_factor_;
+#pragma omp parallel for collapse(4) num_threads(thread_count)
+    for (int n = 0; n < input->batch(); ++n) {
+        for (int h = 0; h < input->head(); ++h) {
+            for (int s = 0; s < input->sequence(); ++s) { // sequence
+                for (int d = 0; d < partial_dimension; ++d) {
+                    int half_dim = input->dimension() / 2;
+                    float in_value = input->dataAt<float>(n, h, s, d);
+                    if (d < half_dim) {
+                        in_value = input->dataAt<float>(n, h, s, d * 2);
+                    } else {
+                        in_value = input->dataAt<float>(n, h, s, 2 * (d - half_dim) + 1);
+                    }
+                    float in_value_2;
+                    if (d < half_dim) {
+                        in_value_2 = -input->dataAt<float>(n, h, s, 2 * d + 1);
+                    } else {
+                        in_value_2 = input->dataAt<float>(n, h, s, 2 * (d - half_dim));
+                    }
+                    // no change
+                    float sin_value = sin_[s + h_cnt_][d];
+                    float cos_value = cos_[s + h_cnt_][d];
+                    auto value = in_value * cos_value + in_value_2 * sin_value;
+                    if (out_dtype == MLLM_TYPE_F32) {
+                        output->setDataAt<float>(n, h, s, d, value);
+                    } else if (out_dtype == MLLM_TYPE_F16) {
+                        output->setDataAt<mllm_fp16_t>(n, h, s, d, MLLM_FP32_TO_FP16(value));
+                    }
+                }
+            }
+        }
+    }
+}
+
+// TODO: Q8_0 KVCache can not be used yet!!
+// TODO: in theory the Q8 path goes through doExecute(inputs, {tmp_out}), so changing doExecute directly should be safe
+ErrorCode CPURoPETree::execute(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) {
+    if (outputs[0]->dtype() == MLLM_TYPE_Q8_0) {
+        auto tmp_out = std::make_shared<Tensor>(outputs[0]->backend());
+        // tmp_out->setBackend(outputs[0]->backend());
+        auto b = outputs[0]->batch();
+        auto h = outputs[0]->head();
+        auto d = outputs[0]->dimension();
+        auto s = outputs[0]->sequence();
+        tmp_out->chls() = outputs[0]->chls();
+        tmp_out->setCtype(outputs[0]->ctype());
+        tmp_out->reshape(b, h, s, d);
+        tmp_out->setDtype(MLLM_TYPE_F32);
+        tmp_out->alloc();
+        doExecute(inputs, {tmp_out});
+#pragma omp parallel for collapse(3) num_threads(thread_count)
+        for (int b = 0; b < tmp_out->batch(); b++) {
+            for (int h = 0; h < tmp_out->head(); h++) {
+                for (int s = 0; s < tmp_out->sequence(); s++) {
+                    quantize_row_q8_0(tmp_out->hostPtr<float>() + tmp_out->offset(b, h, s, 0),
+                                      (char *)outputs[0]->rawHostPtr()
+                                          + outputs[0]->offset(b, h, s, 0) * sizeof(block_q8_0) / QK8_0,
+                                      tmp_out->dimension());
+                }
+            }
+        }
+        return MLLM_NO_ERROR;
+    } else {
+        return doExecute(inputs, outputs);
+    }
+}
+ErrorCode CPURoPETree::doExecute(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) {
+    auto &input = inputs[0];
+    auto &output = outputs[0];
+    auto &tree_ancestor = inputs[1];
+    auto out_dtype = output->dtype();
+    int partial_dimension = (input->dimension()) * partial_rotary_factor_;
+    // auto start_t = mllm_time_us();
+    if (pose_type_ == LLAMAROPE) {
+        throw std::runtime_error("LLAMAROPE is not implemented for RoPETree");
+        rope_llama(input, output);
+    } else if (pose_type_ == HFHUBROPE) {
+        rope_hf(input, output, tree_ancestor);
+    } else if (pose_type_ == PERSIMMONROPE) {
+        throw std::runtime_error("PERSIMMONROPE is not implemented for RoPETree");
+        rope_permission(input, output);
+    } else if (pose_type_ == MLAROPE) {
+        throw std::runtime_error("MLAROPE is not implemented for RoPETree");
+        rope_mla(input, output);
+    } else {
+        MLLM_LOG_ERROR_STREAM << "RoPE type error" << std::endl;
+    }
+#pragma omp parallel for collapse(4) num_threads(thread_count)
+    for (int n = 0; n < input->batch(); ++n) {
+        for (int h = 0; h < input->head(); ++h) {
+            for (int s = 0; s < input->sequence(); ++s) {
+                for (int d = partial_dimension; d < input->dimension(); ++d) {
+                    if (out_dtype == MLLM_TYPE_F32) {
+                        output->setDataAt<float>(n, h, s, d, input->dataAt<float>(n, h, s, d));
+                    } else if (out_dtype == MLLM_TYPE_F16) {
+                        output->setDataAt<mllm_fp16_t>(n, h, s, d, MLLM_FP32_TO_FP16(input->dataAt<float>(n, h, s, d)));
+                    }
+                }
+            }
+        }
+    }
+    h_cnt_ += input->sequence();
+    if (h_cnt_ >= pos_max_) {
+        h_cnt_ = 0;
+    }
+    return Op::execute(inputs, outputs);
+}
+
+// ErrorCode CPURoPETree::updateVerifiedRoPECache(unsigned int withdrawn_length) {
+//     h_cnt_ -= withdrawn_length;
+//     if (h_cnt_ < 0) {
+//         h_cnt_ = 0;
+//     }
+//     return MLLM_NO_ERROR;
+// }
+
+ErrorCode CPURoPETree::load(AbstructLoader &loader) {
+    return Op::load(loader);
+}
+ErrorCode CPURoPETree::free(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) {
+    return Op::free(inputs, outputs);
+}
+} // namespace mllm

diff --git a/src/backends/cpu/op/CPURoPETree.hpp b/src/backends/cpu/op/CPURoPETree.hpp
new file mode 100644
index 00000000..3d9f280c
--- /dev/null
+++ b/src/backends/cpu/op/CPURoPETree.hpp
@@ -0,0 +1,76 @@
+#ifndef MLLM_CPUROPETREE_H
+#define MLLM_CPUROPETREE_H
+
+#include "Op.hpp"
+#include "../CPUBackend.hpp"
+
+namespace mllm {
+
+class CPURoPETree final : public Op {
+public:
+    CPURoPETree(Backend *bn, string opName, int pose_type, int threadCount);
+    CPURoPETree(Backend *bn, string opName, int pose_type, float rope_theta, int max_position_embeddings, int threadCount);
+    CPURoPETree(Backend *bn, string opName, int pose_type, float rope_theta, float partial_rotary_factor, int max_position_embeddings, int threadCount);
+    CPURoPETree(Backend *bn, string opName, OpParam &config, int threadCount);
+    virtual ~CPURoPETree() = default;
+    virtual ErrorCode reshape(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) override;
+    virtual ErrorCode load(AbstructLoader &loader) override;
+    virtual ErrorCode execute(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) override;
+    virtual ErrorCode free(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) override;
+    ErrorCode doExecute(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs);
+
+private:
+    // Tensor freq_;
+    // static Tensor sin_;
+    // static Tensor cos_;
+    static vector<float> theta_;
+    static vector<vector<float>> sin_;
+    static vector<vector<float>> cos_;
+    static int global_pose_type_;
+    static int ishape_old;
+    int rope_theta_ = 10000;
+    int h_cnt_ = 0;
+    int pos_max_ = 16384;
+    int pose_type_ = 4;
+    int ishape;
+    int thread_count = 4;
+    float partial_rotary_factor_ = 1;
+
+    OpParam config_;
+
+    RoPEThetaType rope_type = DEFAULT;
+
+    void rope_llama(shared_ptr<Tensor> input, shared_ptr<Tensor> output);
+    void rope_hf(shared_ptr<Tensor> input, shared_ptr<Tensor> output, shared_ptr<Tensor> tree_ancestor);
+    void rope_permission(shared_ptr<Tensor> input, shared_ptr<Tensor> output);
+    void rope_mla(shared_ptr<Tensor> input, shared_ptr<Tensor> output);
+    void clearCache() override {
+        h_cnt_ = 0;
+    }
+    ErrorCode updateVerifiedRoPECache(unsigned int draft_length);
+};
+
+class CPURoPETreeCreator : public CPUBackend::Creator {
+public:
+    virtual Op *create(OpParam op_param, Backend *bn, string name, int threadCount) const {
+        auto it = op_param.find("rope_type");
+        if (it != op_param.end()) {
+            return new CPURoPETree(bn, name, op_param, threadCount);
+        }
+
+        int pose_type = op_param["pose_type"];
+        if (op_param.find("rope_theta") == op_param.end()) {
+            return new CPURoPETree(bn, name, pose_type, threadCount);
+        }
+        float rope_theta = op_param["rope_theta"];
+        int max_position_embeddings = op_param["max_position_embeddings"];
+        if (op_param.find("partial_rotary_factor") == op_param.end()) {
+            return new CPURoPETree(bn, name, pose_type, rope_theta, max_position_embeddings, threadCount);
+        }
+        float partial_rotary_factor = op_param["partial_rotary_factor"];
+        return new CPURoPETree(bn, name, pose_type, rope_theta, partial_rotary_factor, max_position_embeddings, threadCount);
+    }
+};
+} // namespace mllm
+
+#endif // MLLM_CPUROPETREE_H
\ No newline at end of file
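The position bookkeeping inside rope_hf, extracted as a standalone sketch (illustrative names; `ancestors[0]` is treated as the root slot and `h_cnt` is the running sequence offset, matching `h_cnt_` above):

#include <vector>

std::vector<int> tree_positions(const std::vector<int> &ancestors, int h_cnt) {
    std::vector<int> pos(ancestors.size());
    for (size_t s = 0; s < ancestors.size(); ++s) {
        pos[s] = (s == 0) ? h_cnt : pos[ancestors[s]] + 1; // child position = parent position + 1
    }
    return pos;
}
// tree_positions({-1, 0, 1, 1, 3}, 10) -> {10, 11, 12, 12, 13}: sibling nodes
// (here indices 2 and 3) share the same position id, as two branches of the tree.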
diff --git a/src/backends/cpu/op/CPURoPETree.hpp b/src/backends/cpu/op/CPURoPETree.hpp
new file mode 100644
index 00000000..3d9f280c
--- /dev/null
+++ b/src/backends/cpu/op/CPURoPETree.hpp
@@ -0,0 +1,76 @@
+#ifndef MLLM_CPUROPETREE_H
+#define MLLM_CPUROPETREE_H
+
+#include "Op.hpp"
+#include "../CPUBackend.hpp"
+
+namespace mllm {
+
+class CPURoPETree final : public Op {
+public:
+    CPURoPETree(Backend *bn, string opName, int pose_type, int threadCount);
+    CPURoPETree(Backend *bn, string opName, int pose_type, float rope_theta, int max_position_embeddings, int threadCount);
+    CPURoPETree(Backend *bn, string opName, int pose_type, float rope_theta, float partial_rotary_factor, int max_position_embeddings, int threadCount);
+    CPURoPETree(Backend *bn, string opName, OpParam &config, int threadCount);
+    virtual ~CPURoPETree() = default;
+    virtual ErrorCode reshape(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) override;
+    virtual ErrorCode load(AbstructLoader &loader) override;
+    virtual ErrorCode execute(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) override;
+    virtual ErrorCode free(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) override;
+    ErrorCode doExecute(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs);
+
+private:
+    // Tensor freq_;
+    // static Tensor sin_;
+    // static Tensor cos_;
+    static vector<float> theta_;
+    static vector<vector<float>> sin_;
+    static vector<vector<float>> cos_;
+    static int global_pose_type_;
+    static int ishape_old;
+    int rope_theta_ = 10000;
+    int h_cnt_ = 0;
+    int pos_max_ = 16384;
+    int pose_type_ = 4;
+    int ishape;
+    int thread_count = 4;
+    float partial_rotary_factor_ = 1;
+
+    OpParam config_;
+
+    RoPEThetaType rope_type = DEFAULT;
+
+    void rope_llama(shared_ptr<Tensor> input, shared_ptr<Tensor> output);
+    void rope_hf(shared_ptr<Tensor> input, shared_ptr<Tensor> output, shared_ptr<Tensor> tree_ancestor);
+    void rope_permission(shared_ptr<Tensor> input, shared_ptr<Tensor> output);
+    void rope_mla(shared_ptr<Tensor> input, shared_ptr<Tensor> output);
+    void clearCache() override {
+        h_cnt_ = 0;
+    }
+    ErrorCode updateVerifiedRoPECache(unsigned int draft_length);
+};
+
+class CPURoPETreeCreator : public CPUBackend::Creator {
+public:
+    virtual Op *create(OpParam op_param, Backend *bn, string name, int threadCount) const {
+        auto it = op_param.find("rope_type");
+        if (it != op_param.end()) {
+            return new CPURoPETree(bn, name, op_param, threadCount);
+        }
+
+        int pose_type = op_param["pose_type"];
+        if (op_param.find("rope_theta") == op_param.end()) {
+            return new CPURoPETree(bn, name, pose_type, threadCount);
+        }
+        float rope_theta = op_param["rope_theta"];
+        int max_position_embeddings = op_param["max_position_embeddings"];
+        if (op_param.find("partial_rotary_factor") == op_param.end()) {
+            return new CPURoPETree(bn, name, pose_type, rope_theta, max_position_embeddings, threadCount);
+        }
+        float partial_rotary_factor = op_param["partial_rotary_factor"];
+        return new CPURoPETree(bn, name, pose_type, rope_theta, partial_rotary_factor, max_position_embeddings, threadCount);
+    }
+};
+} // namespace mllm
+
+#endif // MLLM_CPUROPETREE_H
\ No newline at end of file
diff --git a/src/models/qwen/modeling_qwen_sd.hpp b/src/models/qwen/modeling_qwen_sd.hpp
new file mode 100644
index 00000000..0217d685
--- /dev/null
+++ b/src/models/qwen/modeling_qwen_sd.hpp
@@ -0,0 +1,412 @@
+/**
+ * @file modeling_qwen_sd.hpp
+ * @author Zhiyang Chen (zhiyangchen@stu.pku.edu.cn)
+ * @brief
+ * @date 2025-3-5
+ *
+ */
+#ifndef MODELING_QWENSD_HPP
+#define MODELING_QWENSD_HPP
+
+#include "Backend.hpp"
+#include "Layer.hpp"
+#include "Module.hpp"
+#include "Tensor.hpp"
+#include "configuration_qwen.hpp"
+#include <cmath>
+#include "Draft.hpp"
+
+using namespace mllm;
+
+// Copied from modeling_qwen.hpp
+class QWenMLP final : public Module {
+public:
+    QWenMLP() = default;
+    QWenMLP(int hidden_size, int intermediate_size, const QWenNameConfig &names,
+            const std::string &base_name) {
+        gate_proj =
+            Linear(hidden_size, intermediate_size, false, base_name + names._gate_proj_name);
+        silu = SiLU(base_name + "act");
+        up_proj = Linear(hidden_size, intermediate_size, false, base_name + names._up_proj_name);
+        down_proj =
+            Linear(intermediate_size, hidden_size, false, base_name + names._down_proj_name);
+    }
+
+    std::vector<Tensor> Forward(std::vector<Tensor> inputs, std::vector<std::any> args) override {
+        auto x = gate_proj(inputs[0]);
+        x = silu(x);
+        auto y = up_proj(inputs[0]);
+        x = x * y;
+        x = down_proj(x);
+        return {x};
+    }
+
+private:
+    Layer gate_proj;
+    Layer up_proj;
+    Layer down_proj;
+
+    Layer silu;
+};
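QWenMLP::Forward above is the standard SwiGLU feed-forward used by Qwen: the gate path is activated with SiLU and multiplied elementwise into the up path before the down projection. In formula form, with W_g, W_u, W_d the three projections:

    y = down_proj( SiLU(gate_proj(x)) * up_proj(x) ),   where SiLU(z) = z * sigmoid(z)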
+class QWenAttention final : public Module {
+public:
+    QWenAttention() = default;
+    QWenAttention(const QWenConfig &config, const QWenNameConfig &names, const string &base_name) {
+        hidden_size = config.hidden_size;
+        num_heads = config.num_attention_heads;
+        head_dim = config.hidden_size / num_heads;
+        num_key_value_heads = config.num_key_value_heads;
+        num_key_value_groups = num_heads / num_key_value_heads;
+
+        // init layers
+        q_proj = Linear(hidden_size, num_heads * head_dim, true, base_name + names._q_proj_name);
+        k_proj = Linear(hidden_size, num_key_value_heads * head_dim, true,
+                        base_name + names._k_proj_name);
+        v_proj = Linear(hidden_size, num_key_value_heads * head_dim, true,
+                        base_name + names._v_proj_name);
+        o_proj = Linear(num_heads * head_dim, hidden_size, false, base_name + names._o_proj_name);
+
+        q_rope = RoPETree(config.RoPE_type, config.rope_theta, config.max_position_embeddings,
+                          base_name + "q_rope");
+        k_rope = RoPETree(config.RoPE_type, config.rope_theta, config.max_position_embeddings,
+                          base_name + "k_rope");
+        k_cache = KVCache(num_key_value_groups, config.cache_limit, base_name + "k_cache");
+        v_cache = KVCache(num_key_value_groups, config.cache_limit, base_name + "v_cache");
+        // k_cache = KVCacheTree(num_key_value_groups, config.cache_limit, base_name + "k_cache");
+        // v_cache = KVCacheTree(num_key_value_groups, config.cache_limit, base_name + "v_cache");
+        mask = CausalTreeMask(base_name + "tree_mask");
+        softmax = Softmax(DIMENSION, base_name + "softmax");
+    }
+
+    std::vector<Tensor> Forward(std::vector<Tensor> inputs, std::vector<std::any> args) override {
+        auto query_states = q_proj(inputs[0]);
+        auto key_states = k_proj(inputs[1]);
+        auto value_states = v_proj(inputs[2]);
+        auto &tree_ancestor = inputs[3];
+
+        // [batch, heads, sequence, dims]
+        query_states = query_states.view(-1, num_heads, -1, head_dim);
+        key_states = key_states.view(-1, num_key_value_heads, -1, head_dim);
+        value_states = value_states.view(-1, num_key_value_heads, -1, head_dim);
+
+        // rotary position embedding (tree-aware)
+        query_states = q_rope(query_states, tree_ancestor);
+        key_states = k_rope(key_states, tree_ancestor);
+
+        // kv cache
+        key_states = k_cache(key_states);
+        value_states = v_cache(value_states);
+
+        // attention weight
+        auto atten_weight =
+            Tensor::mm(query_states, key_states.transpose(Chl::SEQUENCE, Chl::DIMENSION))
+            / std::sqrt(head_dim);
+        atten_weight = mask(atten_weight, k_cache.getCacheSeqLen(), tree_ancestor);
+        atten_weight = softmax(atten_weight, k_cache.getCacheSeqLen());
+
+        // attention output
+        auto atten_output = Tensor::mm(atten_weight, value_states);
+        atten_output = atten_output.view(-1, 1, -1, head_dim * num_heads);
+        atten_output = o_proj(atten_output);
+        return {atten_output};
+    }
+
+    vector<KVCache *> get_cache() {
+        return {&k_cache, &v_cache};
+    }
+    vector<RoPETree *> get_rope() {
+        return {&q_rope, &k_rope};
+    }
+
+private:
+    int hidden_size;
+    int num_heads;
+    int head_dim;
+    int num_key_value_heads;
+    int num_key_value_groups;
+    Layer q_proj;
+    Layer k_proj;
+    Layer v_proj;
+    Layer o_proj;
+    RoPETree q_rope;
+    RoPETree k_rope;
+    KVCache k_cache;
+    KVCache v_cache;
+    // KVCacheTree k_cache;
+    // KVCacheTree v_cache;
+    CausalTreeMask mask;
+    Softmax softmax;
+};
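CausalTreeMask generalizes the causal mask to a token tree: a draft token may attend to the verified (cached) prefix and to its own ancestor chain, but not to sibling branches, so several candidate continuations can be verified in a single forward pass. A minimal visibility predicate under the same ancestor encoding as above (draft-local indices, -1 marks a root; a hypothetical helper for illustration, not this patch's kernel, which additionally always exposes the cached prefix):

    #include <vector>
    // True if draft token i may attend to draft token j (both draft-local indices).
    static bool sketchTreeVisible(int i, int j, const std::vector<int> &anc) {
        for (int cur = i; cur >= 0; cur = anc[cur]) {
            if (cur == j) return true; // j lies on i's ancestor path (or is i itself)
        }
        return false;                  // sibling branch: masked out
    }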
+class QWenDecoder final : public Module {
+public:
+    QWenDecoder() = default;
+    QWenDecoder(const QWenConfig &config, const QWenNameConfig &names, const string &base_name) {
+        self_atten = QWenAttention(config, names, base_name + names._attn_base_name);
+        mlp = QWenMLP(config.hidden_size, config.intermediate_size, names,
+                      base_name + names._ffn_base_name);
+        input_layernorm =
+            RMSNorm(config.hidden_size, config.rms_norm_eps, base_name + names._attn_norm_name);
+        post_attention_layernorm =
+            RMSNorm(config.hidden_size, config.rms_norm_eps, base_name + names._ffn_norm_name);
+    }
+
+    std::vector<Tensor> Forward(std::vector<Tensor> inputs, std::vector<std::any> args) override {
+        auto &tree_ancestors = inputs[1];
+        auto x = input_layernorm(inputs[0]);
+        x = self_atten({x, x, x, tree_ancestors})[0];
+        auto tmp = x + inputs[0];
+        x = post_attention_layernorm(tmp);
+        x = mlp({x})[0];
+        x = x + tmp;
+        return {x};
+    }
+
+    QWenAttention &get_attention() {
+        return self_atten;
+    }
+
+private:
+    QWenAttention self_atten;
+    QWenMLP mlp;
+    Layer input_layernorm;
+    Layer post_attention_layernorm;
+};
+
+class QWenModel final : public Module {
+public:
+    QWenModel() = default;
+    QWenModel(const QWenConfig &config, const QWenNameConfig &names, const string &base_name) {
+        blocks = List<QWenDecoder>(config.num_hidden_layers, config, names, base_name);
+        norm = RMSNorm(config.hidden_size, config.rms_norm_eps, names.post_norm_name);
+    }
+
+    std::vector<Tensor> Forward(std::vector<Tensor> inputs, std::vector<std::any> args) override {
+        auto x = inputs[0];
+        auto &tree_ancestors = inputs[1];
+        for (auto &block : blocks) { x = block({x, tree_ancestors})[0]; }
+        x = norm(x);
+        return {x};
+    }
+
+    void clear_kvcache() override {
+        for (auto &block : blocks) {
+            auto kvcache = block.get_attention().get_cache();
+            for (auto &cache : kvcache) { cache->clearCache(); }
+            auto ropes = block.get_attention().get_rope();
+            for (auto &rope : ropes) { rope->clearCache(); }
+        }
+    }
+
+    // void update_state(const Pool &candidates) {
+    //     unsigned int draft_length = candidates.last_draft_length;
+    //     unsigned int withdrawn_length = candidates.last_draft_length - candidates.last_accept_length;
+    //     for (auto &block : blocks) {
+    //         auto kvcache = block.get_attention().get_cache();
+    //         for (auto &cache : kvcache) { cache->updateVerifiedKVCache(candidates.last_accept_position_ids, draft_length); }
+    //         auto ropes = block.get_attention().get_rope();
+    //         for (auto &rope : ropes) { rope->updateVerifiedRoPECache(withdrawn_length); }
+    //     }
+    // }
+
+private:
+    std::vector<QWenDecoder> blocks;
+    Layer norm;
+};
+
+class QWenForCausalLM final : public Module {
+public:
+    QWenForCausalLM(QWenConfig &config) {
+        auto names = config.names_config;
+        hidden_size = config.hidden_size;
+        tie_embedding_words = config.tie_embedding_words;
+        embedding = Embedding(config.vocab_size, config.hidden_size, names.token_embd_name);
+        model = QWenModel(config, names, names.blk_name);
+
+        // Qwen-0.5B uses a tied embedding; the others use nn.Linear()
+        if (tie_embedding_words) {
+            lm_head = Parameter(1, config.vocab_size, 1, config.hidden_size,
+                                names.token_embd_name + ".weight");
+        } else {
+            lm_head_layer =
+                Linear(config.hidden_size, config.vocab_size, false, names.lm_head_name);
+        }
+
+        tp.is_decoding = false;
+        // TODO: load the datastore
+        sa.add_tokens({40, 1079, 264, 3460});
+        sa.add_tokens({13, 3017, 829, 374, 1207, 1103});
+    }
+
+    std::vector<Tensor> Forward(std::vector<Tensor> inputs, std::vector<std::any> args) override {
+        // if (args.size() > 0 && Tensor::tensor_status == TENSOR_STATIC_INIT) { // only update the kv cache on the first Forward call
+        //     try {
+        //         Pool *candidates = std::any_cast<Pool *>(args[0]);
+        //         if (candidates->is_decoding) {
+        //             model.update_state(*candidates);
+        //         }
+        //     } catch (const std::bad_any_cast &) {
+        //     }
+        // }
+
+        auto x = embedding(inputs[0]);
+        auto &tree_ancestors = inputs[1];
+
+        // go through the model
+        auto outputs = model({x, tree_ancestors})[0];
+        if (tie_embedding_words) {
+            outputs = Tensor::mm(outputs, lm_head().transpose(Chl::SEQUENCE, Chl::DIMENSION));
+        } else {
+            outputs = lm_head_layer(outputs);
+        }
+        return {outputs};
+    }
+
+    void clear_kvcache() override {
+        model.clear_kvcache();
+    }
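generate() below runs the draft-then-verify loop of speculative decoding: each step feeds the last accepted token plus a tree of draft tokens, then keeps the longest draft prefix that the model itself would have produced. With greedy (argmax) verification the acceptance rule reduces to a prefix match; a minimal single-trace sketch of that rule (a hypothetical helper for illustration, the real logic lives in TracePool::evalPosterior and handles multiple traces):

    #include <vector>
    // draft[k]  : k-th drafted token of one trace
    // argmax[k] : model's argmax at the position preceding draft[k]
    //             (argmax[0] is the prediction at the last verified token)
    static unsigned int sketchAcceptLength(const std::vector<unsigned int> &draft,
                                           const std::vector<unsigned int> &argmax) {
        unsigned int n = 0;
        while (n < draft.size() && argmax[n] == draft[n]) ++n;
        return n; // draft[0..n) are accepted; argmax[n] becomes the free "bonus" token
    }

One model call therefore yields accept_length + 1 tokens, which is where the speedup over token-by-token decoding comes from.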
+    void generate(
+        Tensor &input_ids, const LlmTextGeneratorOpts &opt, const std::function<bool(unsigned int)> &call_back)
+        override {
+        auto post_processing_for_SD = [](const std::vector<unsigned int> &token_indices, const std::vector<int> &tree_anc,
+                                         unsigned int input_length, Tensor &input_ids, Tensor &tree_ancestors, const vector<Tensor *> &clean_tensors) {
+            input_ids.reshape(1, 1, input_length, 1);
+            input_ids.alloc();
+            tree_ancestors.reshape(1, 1, input_length, 1);
+            tree_ancestors.alloc();
+            for (unsigned int seq_id = 0; seq_id < input_length; ++seq_id) {
+                input_ids.setDataAt<float>(0, 0, seq_id, 0, token_indices[seq_id]);
+                tree_ancestors.setDataAt<int>(0, 0, seq_id, 0, tree_anc[seq_id]);
+            }
+            for (auto tensor : clean_tensors) {
+                tensor->reshape(0, 0, 0, 0);
+                tensor->alloc();
+            }
+        };
+
+        if (!opt.do_sample) {
+            // greedy search
+            if (!text_generator_ || text_generator_->type() != LLmTextGeneratorType::kGreedySearchForSD) {
+                text_generator_ = std::make_shared<LlmTextGenerator>(LLmTextGeneratorType::kGreedySearchForSD, opt);
+            }
+        } else {
+            // TODO: nucleus sampling for SD
+            throw std::runtime_error("Not implemented yet");
+        }
+
+        std::vector<unsigned int> seq(input_ids.sequence());
+        for (int i = 0; i < input_ids.sequence(); ++i) {
+            auto value = input_ids.dataAt<float>(0, 0, i, 0);
+            seq[i] = ((unsigned int)(value));
+        }
+        updateContext(seq);
+
+        tree_ancestors = Tensor(1, 1, 1, 1, input_ids.backend(), true);
+        tree_ancestors.setName("tree_ancestors");
+        tree_ancestors.setDtype(MLLM_TYPE_I32);
+        tp.is_decoding = false;
+        static_cast<CPUBackend *>(Backend::global_backends[MLLM_CPU])->setUsingDraft(false); // drafts are not used during prefill
+        static_cast<CPUBackend *>(Backend::global_backends[MLLM_CPU])->setLastDraftLength(0);
+        static_cast<CPUBackend *>(Backend::global_backends[MLLM_CPU])->setLastVerifiedPositionIds({});
+
+        unsigned int cur_seq_length = input_ids.sequence();
+        std::vector<unsigned int> predicted_token_ids;
+
+        for (int step = 0; step < opt.max_new_tokens; ++step) {
+            auto _out = (*this)({input_ids, tree_ancestors}, &tp); // batch_size * seq_len * 1 * vocab_size
+            auto out_token = text_generator_->generate_SD(_out[0], tp);
+            tp.is_decoding = true;
+
+            // streaming output
+            bool is_end_generate = false;
+            std::vector<unsigned int> new_predicted_token_ids;
+            if (step > 0) {
+                const auto &trace = tp.get_accepted_trace();
+                auto accept_length = tp.get_accepted_length();
+                for (unsigned int i = 0; i < accept_length; i++) {
+                    auto accept_token_idx = trace.trace_tokens[i];
+                    if (!call_back(accept_token_idx)) {
+                        is_end_generate = true;
+                        break;
+                    }
+                    cur_seq_length += 1;
+                    new_predicted_token_ids.push_back(accept_token_idx);
+                }
+            }
+            if (is_end_generate || !call_back(out_token)) {
+                break;
+            }
+            // at least one token is verified per step
+            predicted_token_ids.push_back(out_token);
+            new_predicted_token_ids.push_back(out_token);
+            cur_seq_length += 1;
+
+            updateContext(new_predicted_token_ids);
+            updateTracePool(cur_seq_length, out_token);
+
+            // prepare the draft for the next forward pass
+            std::vector<unsigned int> new_token_ids = {out_token};
+            std::vector<unsigned int> position_ids = {cur_seq_length - 1};
+            std::vector<int> tree_anc = {-1};
+            unsigned int draft_len = tp.generate_draft(new_token_ids, position_ids, tree_anc, cur_seq_length);
+            post_processing_for_SD(new_token_ids, tree_anc, draft_len + 1, input_ids, tree_ancestors, {});
+
+            if (step == 0) {
+                static_cast<CPUBackend *>(Backend::global_backends[MLLM_CPU])->setUsingDraft(true);
+            }
+            static_cast<CPUBackend *>(Backend::global_backends[MLLM_CPU])->setLastDraftLength(tp.last_draft_length);
+            static_cast<CPUBackend *>(Backend::global_backends[MLLM_CPU])->setLastVerifiedPositionIds(tp.last_accept_position_ids);
+        }
+        tp.reset();
+        sa.reset();
+        // std::cout << std::endl;
+        // for (int i = 0; i < predicted_token_ids.size(); i++) {
+        //     std::cout << predicted_token_ids[i] << ' ';
+        // }
+    }
+
+private:
+    int hidden_size;
+    bool tie_embedding_words;
+    Layer embedding;
+    Parameter lm_head;
+    Layer lm_head_layer;
+    QWenModel model;
+
+    TracePool tp;
+    SuffixAutomaton sa;
+    Tensor tree_ancestors;
+
+    void updateTracePool(unsigned int cur_seq_length, unsigned int last_token_id) {
+        auto [idx, len] = sa.lookup(last_token_id);
+        std::vector<unsigned int> seq;
+        auto dlen = sa.gen_draft(seq, idx, len, last_token_id, 10);
+
+        // TODO: keep a trace-reuse policy
+        tp.clear_trace();
+
+        if (dlen > 0) {
+            tp.add_trace(seq);
+            // std::cout << dlen << std::endl;
+        }
+
+        // for testing with fixed traces:
+        // if (last_token_id == 40) {
+        //     tp.add_trace({88, 3017, 829, 374, 88});
+        //     tp.add_trace({1079, 264, 3460, 99});
+        // }
+    };
+
+    void updateContext(const std::vector<unsigned int> &new_predicted_token_ids) {
+        sa.add_tokens(new_predicted_token_ids);
+    }
+};
+
+#endif //! MODELING_QWENSD_HPP
\ No newline at end of file
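Drafting in this model is model-free: SuffixAutomaton::lookup finds the longest previously generated context ending in the last token, and gen_draft replays its historical continuation as the draft, the same retrieval idea as prompt-lookup decoding. A minimal hash-map sketch of that retrieval, standalone and hypothetical, with a single-token key instead of the automaton's arbitrary-length suffix match:

    #include <unordered_map>
    #include <vector>
    // Propose up to max_len tokens by copying what followed `last` the most
    // recent time it appeared in `history`.
    static std::vector<unsigned int> sketchLookupDraft(const std::vector<unsigned int> &history,
                                                       unsigned int last, int max_len) {
        std::unordered_map<unsigned int, size_t> last_pos; // token -> last index seen
        for (size_t i = 0; i + 1 < history.size(); ++i) last_pos[history[i]] = i;
        std::vector<unsigned int> draft;
        auto it = last_pos.find(last);
        if (it == last_pos.end()) return draft;            // never seen: no draft
        for (size_t j = it->second + 1; j < history.size() && (int)draft.size() < max_len; ++j)
            draft.push_back(history[j]);                   // replay the continuation
        return draft;
    }

The suffix automaton achieves the same contract with amortized O(1) updates per token and longer matched contexts; either way the drafts are cheap candidates that the full model then verifies in one pass.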