Minor terminology updates (mlc-ai#262)
tqchen authored May 29, 2023
1 parent 642669d commit 60e2176
Showing 1 changed file with 21 additions and 20 deletions.
41 changes: 21 additions & 20 deletions cpp/llm_chat.cc
@@ -254,8 +254,8 @@ class LLMChat {
static_cast<int>(kDLCPU), 0,
static_cast<int>(relax_vm::AllocatorType::kPooled));

- encoding_func_ = vm_->GetFunction("prefill");
- decoding_func_ = vm_->GetFunction("decode");
+ prefill_func_ = vm_->GetFunction("prefill");
+ decode_func_ = vm_->GetFunction("decode");
encoding_without_cache_func_ = vm_->GetFunction("encoding_without_cache");
softmax_func_ = vm_->GetFunction("softmax_with_temperature");
get_metadata_func_ = vm_->GetFunction("get_metadata");
@@ -374,7 +374,7 @@ class LLMChat {
this->conversation_.Reset();
this->ResetRuntimeStats();
this->ResetKVCache();
- this->filled_kv_cache_len_ = 0;
+ this->total_seq_len_ = 0;
}

/*! \brief reset the runtime stats. */
@@ -405,7 +405,7 @@ class LLMChat {
std::vector<int32_t> tokens;
std::vector<std::string> prompts;

- if (this->filled_kv_cache_len_ == 0) {
+ if (this->total_seq_len_ == 0) {
prompts = this->conversation_.GetPromptArray();
if (this->conversation_.add_bos) {
tokens.insert(tokens.begin(), bos_token_id_);
@@ -417,11 +417,11 @@ class LLMChat {
std::string all_prompt = GetConcatPrompt(prompts, 0, 0);
std::vector<int32_t> encoded = this->tokenizer_->Encode(all_prompt);
tokens.insert(tokens.end(), encoded.begin(), encoded.end());
- if (this->filled_kv_cache_len_ + tokens.size() + this->mean_gen_len_ < this->max_window_size_) {
+ if (this->total_seq_len_ + tokens.size() + this->mean_gen_len_ < this->max_window_size_) {
return tokens;
}
// need shift window and re-encode
- this->filled_kv_cache_len_ = 0;
+ this->total_seq_len_ = 0;
this->ResetKVCache();
tokens.clear();
if (this->conversation_.add_bos) {
@@ -505,9 +505,9 @@ class LLMChat {

auto tstart = std::chrono::high_resolution_clock::now();

- int32_t new_seq_len = filled_kv_cache_len_ + token_len;
+ int32_t new_seq_len = total_seq_len_ + token_len;
NDArray logits_on_device = this->Forward(prompt_tokens, new_seq_len);
- filled_kv_cache_len_ = new_seq_len;
+ total_seq_len_ = new_seq_len;

int32_t next_token = this->SampleTokenFromLogits(logits_on_device, temperature_, top_p_);

@@ -525,8 +525,8 @@ class LLMChat {

auto tstart = std::chrono::high_resolution_clock::now();

- NDArray logits_on_device = this->Forward({last_token}, filled_kv_cache_len_ + 1);
- filled_kv_cache_len_ += 1;
+ NDArray logits_on_device = this->Forward({last_token}, total_seq_len_ + 1);
+ total_seq_len_ += 1;

int32_t next_token = this->SampleTokenFromLogits(logits_on_device, temperature_, top_p_);

@@ -658,12 +658,12 @@ class LLMChat {
}
// resize kv to remove the context
fkvcache_array_popn_(kv_cache_, backoff);
- filled_kv_cache_len_ -= backoff;
+ total_seq_len_ -= backoff;
}
}
}
// TODO(mlc-team): add another per convo seq len trigger
- if (filled_kv_cache_len_ >= max_window_size_) {
+ if (total_seq_len_ >= max_window_size_) {
stop_triggered_ = true;
}
if (stop_triggered_) {
@@ -674,14 +674,15 @@ class LLMChat {
// run forward compute
NDArray Forward(std::vector<int32_t> input_tokens, int64_t cur_pos) {
Array<ObjectRef> ret;
- if (input_tokens.size() > 1 && encoding_func_.defined()) {
+ if (input_tokens.size() > 1 && prefill_func_.defined()) {
NDArray input_data = this->GetInputTokenNDArray(input_tokens);
- ret = encoding_func_(input_data, ShapeTuple({cur_pos}), kv_cache_, params_);
+ ret = prefill_func_(input_data, ShapeTuple({cur_pos}), kv_cache_, params_);
} else {
+ // running decode function when prefill is not available
for (int i = 0; i < input_tokens.size(); ++i) {
NDArray input_data = this->GetInputTokenNDArray({input_tokens[i]});
int64_t pos = cur_pos + i + 1 - input_tokens.size();
- ret = decoding_func_(input_data, ShapeTuple({pos}), kv_cache_, params_);
+ ret = decode_func_(input_data, ShapeTuple({pos}), kv_cache_, params_);
}
}
return Downcast<NDArray>(ret[0]);
@@ -780,7 +780,7 @@ class LLMChat {
// conversation
Conversation conversation_;
// total sequence len,
- int64_t filled_kv_cache_len_{0};
+ int64_t total_seq_len_{0};
// max window size, mean generation length
int64_t max_window_size_{768}, mean_gen_len_{128};
// shift window fill factor
@@ -818,9 +818,9 @@ class LLMChat {
// The vm module
Module vm_;
// encoding function
- PackedFunc encoding_func_;
+ PackedFunc prefill_func_;
// decoding function
- PackedFunc decoding_func_;
+ PackedFunc decode_func_;
// encoding without cache
PackedFunc encoding_without_cache_func_;
// softmax
@@ -968,8 +968,8 @@ class LLMChatModule : public ModuleNode {
static_cast<int>(relax_vm::AllocatorType::kPooled), static_cast<int>(kDLCPU), 0,
static_cast<int>(relax_vm::AllocatorType::kPooled));

- chat_->encoding_func_ = chat_->vm_->GetFunction("prefill");
- chat_->decoding_func_ = chat_->vm_->GetFunction("decode");
+ chat_->prefill_func_ = chat_->vm_->GetFunction("prefill");
+ chat_->decode_func_ = chat_->vm_->GetFunction("decode");
chat_->encoding_without_cache_func_ = chat_->vm_->GetFunction("encoding_without_cache");
chat_->softmax_func_ = chat_->vm_->GetFunction("softmax_with_temperature");
chat_->get_metadata_func_ = chat_->vm_->GetFunction("get_metadata");
