diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc
index c6758ead1c..46b7fc4117 100644
--- a/cpp/llm_chat.cc
+++ b/cpp/llm_chat.cc
@@ -254,8 +254,8 @@ class LLMChat {
                                           static_cast<int>(kDLCPU), 0,
                                           static_cast<int>(relax_vm::AllocatorType::kPooled));
-    encoding_func_ = vm_->GetFunction("prefill");
-    decoding_func_ = vm_->GetFunction("decode");
+    prefill_func_ = vm_->GetFunction("prefill");
+    decode_func_ = vm_->GetFunction("decode");
     encoding_without_cache_func_ = vm_->GetFunction("encoding_without_cache");
     softmax_func_ = vm_->GetFunction("softmax_with_temperature");
     get_metadata_func_ = vm_->GetFunction("get_metadata");
@@ -374,7 +374,7 @@ class LLMChat {
     this->conversation_.Reset();
     this->ResetRuntimeStats();
     this->ResetKVCache();
-    this->filled_kv_cache_len_ = 0;
+    this->total_seq_len_ = 0;
   }

   /*! \brief reset the runtime stats. */
@@ -405,7 +405,7 @@ class LLMChat {
     std::vector<int32_t> tokens;
     std::vector<std::string> prompts;

-    if (this->filled_kv_cache_len_ == 0) {
+    if (this->total_seq_len_ == 0) {
       prompts = this->conversation_.GetPromptArray();
       if (this->conversation_.add_bos) {
         tokens.insert(tokens.begin(), bos_token_id_);
@@ -417,11 +417,11 @@ class LLMChat {
     std::string all_prompt = GetConcatPrompt(prompts, 0, 0);
     std::vector<int32_t> encoded = this->tokenizer_->Encode(all_prompt);
     tokens.insert(tokens.end(), encoded.begin(), encoded.end());
-    if (this->filled_kv_cache_len_ + tokens.size() + this->mean_gen_len_ < this->max_window_size_) {
+    if (this->total_seq_len_ + tokens.size() + this->mean_gen_len_ < this->max_window_size_) {
       return tokens;
     }
     // need shift window and re-encode
-    this->filled_kv_cache_len_ = 0;
+    this->total_seq_len_ = 0;
     this->ResetKVCache();
     tokens.clear();
     if (this->conversation_.add_bos) {
@@ -505,9 +505,9 @@ class LLMChat {
     auto tstart = std::chrono::high_resolution_clock::now();

-    int32_t new_seq_len = filled_kv_cache_len_ + token_len;
+    int32_t new_seq_len = total_seq_len_ + token_len;
     NDArray logits_on_device = this->Forward(prompt_tokens, new_seq_len);
-    filled_kv_cache_len_ = new_seq_len;
+    total_seq_len_ = new_seq_len;

     int32_t next_token = this->SampleTokenFromLogits(logits_on_device, temperature_, top_p_);
@@ -525,8 +525,8 @@ class LLMChat {
     auto tstart = std::chrono::high_resolution_clock::now();

-    NDArray logits_on_device = this->Forward({last_token}, filled_kv_cache_len_ + 1);
-    filled_kv_cache_len_ += 1;
+    NDArray logits_on_device = this->Forward({last_token}, total_seq_len_ + 1);
+    total_seq_len_ += 1;

     int32_t next_token = this->SampleTokenFromLogits(logits_on_device, temperature_, top_p_);
@@ -658,12 +658,12 @@ class LLMChat {
           }
           // resize kv to remove the context
           fkvcache_array_popn_(kv_cache_, backoff);
-          filled_kv_cache_len_ -= backoff;
+          total_seq_len_ -= backoff;
         }
       }
     }
     // TODO(mlc-team): add another per convo seq len trigger
-    if (filled_kv_cache_len_ >= max_window_size_) {
+    if (total_seq_len_ >= max_window_size_) {
       stop_triggered_ = true;
     }
     if (stop_triggered_) {
@@ -674,14 +674,15 @@ class LLMChat {
   // run forward compute
   NDArray Forward(std::vector<int32_t> input_tokens, int64_t cur_pos) {
     Array<ObjectRef> ret;
-    if (input_tokens.size() > 1 && encoding_func_.defined()) {
+    if (input_tokens.size() > 1 && prefill_func_.defined()) {
       NDArray input_data = this->GetInputTokenNDArray(input_tokens);
-      ret = encoding_func_(input_data, ShapeTuple({cur_pos}), kv_cache_, params_);
+      ret = prefill_func_(input_data, ShapeTuple({cur_pos}), kv_cache_, params_);
     } else {
+      // run decode function when prefill is not available
       for (int i = 0; i < input_tokens.size(); ++i) {
         NDArray input_data = this->GetInputTokenNDArray({input_tokens[i]});
         int64_t pos = cur_pos + i + 1 - input_tokens.size();
-        ret = decoding_func_(input_data, ShapeTuple({pos}), kv_cache_, params_);
+        ret = decode_func_(input_data, ShapeTuple({pos}), kv_cache_, params_);
       }
     }
     return Downcast<NDArray>(ret[0]);
@@ -780,7 +781,7 @@ class LLMChat {
   // conversation
   Conversation conversation_;
   // total sequence len,
-  int64_t filled_kv_cache_len_{0};
+  int64_t total_seq_len_{0};
   // max window size, mean generation length
   int64_t max_window_size_{768}, mean_gen_len_{128};
   // shift window fill factor
@@ -818,9 +819,9 @@ class LLMChat {
   // The vm module
   Module vm_;
   // encoding function
-  PackedFunc encoding_func_;
+  PackedFunc prefill_func_;
   // decoding function
-  PackedFunc decoding_func_;
+  PackedFunc decode_func_;
   // encoding without cache
   PackedFunc encoding_without_cache_func_;
   // softmax
@@ -968,8 +969,8 @@ class LLMChatModule : public ModuleNode {
           static_cast<int>(relax_vm::AllocatorType::kPooled), static_cast<int>(kDLCPU), 0,
           static_cast<int>(relax_vm::AllocatorType::kPooled));
-      chat_->encoding_func_ = chat_->vm_->GetFunction("prefill");
-      chat_->decoding_func_ = chat_->vm_->GetFunction("decode");
+      chat_->prefill_func_ = chat_->vm_->GetFunction("prefill");
+      chat_->decode_func_ = chat_->vm_->GetFunction("decode");
       chat_->encoding_without_cache_func_ = chat_->vm_->GetFunction("encoding_without_cache");
       chat_->softmax_func_ = chat_->vm_->GetFunction("softmax_with_temperature");
       chat_->get_metadata_func_ = chat_->vm_->GetFunction("get_metadata");
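
For context, the hunk at @@ -674,14 +674,15 @@ carries the dispatch this patch renames toward: use the batched prefill kernel when more than one token is pending and the kernel exists, otherwise fall back to calling the decode kernel one token at a time. Below is a minimal standalone sketch of that dispatch; the names RunTokens, PrefillFn, and DecodeFn are illustrative placeholders and are not part of this patch or of the TVM/MLC API.

// Hypothetical sketch of the prefill-vs-decode dispatch (placeholder names).
#include <cstddef>
#include <cstdint>
#include <functional>
#include <vector>

using PrefillFn = std::function<void(const std::vector<int32_t>&, int64_t)>;
using DecodeFn = std::function<void(int32_t, int64_t)>;

void RunTokens(const std::vector<int32_t>& tokens, int64_t cur_pos,
               const PrefillFn& prefill, const DecodeFn& decode) {
  if (tokens.size() > 1 && prefill) {
    // Prefill consumes the whole pending prompt in one call; cur_pos is the
    // total sequence length once these tokens are in the KV cache.
    prefill(tokens, cur_pos);
  } else {
    // No prefill kernel: feed tokens through decode one at a time, computing
    // each token's absolute position the same way the patched Forward does.
    for (std::size_t i = 0; i < tokens.size(); ++i) {
      int64_t pos = cur_pos + static_cast<int64_t>(i) + 1 -
                    static_cast<int64_t>(tokens.size());
      decode(tokens[i], pos);
    }
  }
}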