-
Notifications
You must be signed in to change notification settings - Fork 103
feat: decode entrance support mtp #245
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -17,6 +17,7 @@ grpc::Status PrefillRpcServerNew::RemoteGenerateNew(grpc::ServerContext* | |||
RemoteGenerateResponsePBNew* response) { | ||||
auto modified_request = const_cast<RemoteGenerateRequestPBNew*>(request); | ||||
GenerateInputPB* mutable_input = modified_request->mutable_input(); | ||||
|
||||
|
||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [nitpick] Unnecessary blank line added without purpose.
Suggested change
Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||||
// reset request_id in prefill | ||||
auto request_id = loading_cache_requests_.fetch_add(1, std::memory_order_relaxed); | ||||
|
@@ -69,12 +70,6 @@ grpc::Status PrefillRpcServerNew::RemoteGenerateNew(grpc::ServerContext* | |||
} | ||||
prefill_context.wait_store_cache_done_time_us = currentTimeUs(); | ||||
|
||||
// TODO: notify remote store for hidden state | ||||
// if (engine_->isMTPEagle() && | ||||
// engine_->getDevice()->getDeviceProperties().tp_rank == 0 && | ||||
// !request->mtp_hidden_states_key.empty()) { | ||||
//} | ||||
|
||||
RTP_LLM_LOG_DEBUG("request [%s] RemoteGenerateNew success, response is %s", | ||||
prefill_context.request_key.c_str(), | ||||
response->ShortDebugString().c_str()); | ||||
|
@@ -109,11 +104,6 @@ bool PrefillRpcServerNew::validRequest(PrefillGenerateContextNew& prefill_contex | |||
return false; | ||||
} | ||||
|
||||
if (engine_->isMTPEagle() && request->mtp_hidden_states_key().empty()) { | ||||
RTP_LLM_LOG_WARNING("request [%s] mtp_hidden_states_key is empty", prefill_context.request_key.c_str()); | ||||
return false; | ||||
} | ||||
|
||||
if (request->use_mla() != maga_init_params_.gpt_init_parameter.use_mla_) { | ||||
RTP_LLM_LOG_WARNING("request [%s] request is invalid, mla config not match", | ||||
prefill_context.request_key.c_str()); | ||||
|
@@ -243,6 +233,7 @@ void PrefillRpcServerNew::constructRemoteLoadRequest(PrefillGenerateContextNew& | |||
ErrorInfo PrefillRpcServerNew::generateFirstToken(PrefillGenerateContextNew& prefill_context) { | ||||
auto stream = prefill_context.getStream(); | ||||
engine_->enqueue(stream); | ||||
|
||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [nitpick] Unnecessary trailing whitespace added.
Suggested change
Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||||
while (!stream->finished() || stream->hasOutput()) { | ||||
const auto result = stream->nextOutput(); | ||||
if (!result.ok()) { | ||||
|
@@ -254,15 +245,34 @@ ErrorInfo PrefillRpcServerNew::generateFirstToken(PrefillGenerateContextNew& pre | |||
auto response_output = prefill_context.response->mutable_output(); | ||||
QueryConverter::transResponse( | ||||
response_output, &(result.value()), maga_init_params_.gpt_init_parameter.misc_config.aux_string); | ||||
// should only generate one token | ||||
break; | ||||
|
||||
if (engine_->isMTPEagle()) { | ||||
stream->waitForPrefillMtpReady(); | ||||
break; | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The break statement causes the loop to exit early only for MTP streams, potentially skipping output processing for non-finished MTP streams. The loop should continue processing outputs until the stream is finished, regardless of MTP status.
Suggested change
Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||||
} | ||||
} | ||||
|
||||
if (prefill_context.getStream()->finished()) { | ||||
RTP_LLM_LOG_INFO("request [%s] generate first token success and finished", prefill_context.request_key.c_str()); | ||||
} | ||||
auto first_token = prefill_context.getStream()->currentExecuteTokens()[0]; | ||||
prefill_context.response->set_finished(prefill_context.getStream()->finished()); | ||||
prefill_context.response->set_first_generate_token_id(first_token); | ||||
if (engine_->isMTPEagle()) { | ||||
auto stream = prefill_context.getStream(); | ||||
auto propose_tokens = stream->getProposeToken(); | ||||
if (!propose_tokens.empty()) { | ||||
prefill_context.response->mutable_propose_token_ids()->CopyFrom( | ||||
{propose_tokens.begin(), propose_tokens.end()}); | ||||
Comment on lines
+265
to
+266
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The CopyFrom operation with initializer list creates unnecessary temporary objects. Use RepeatedField::Add() in a loop or assign directly to avoid the temporary vector creation. Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||||
RTP_LLM_LOG_DEBUG("request [%s] added %d propose tokens", | ||||
prefill_context.request_key.c_str(), | ||||
propose_tokens.size()); | ||||
} else { | ||||
RTP_LLM_LOG_WARNING("request [%s] MTP enabled but no propose tokens generated", | ||||
prefill_context.request_key.c_str()); | ||||
} | ||||
} | ||||
|
||||
return ErrorInfo::OkStatus(); | ||||
} | ||||
|
||||
|
@@ -437,6 +447,30 @@ grpc::Status PrefillRpcServerNew::RemoteStore(grpc::ServerContext* server | |||
} | ||||
} | ||||
} | ||||
// MTP kv cache mappings | ||||
if (engine_->isMTPEagle() && propose_maga_init_params_) { | ||||
for (size_t mtp_model_id = 0; mtp_model_id < propose_maga_init_params_->mtp_model_params_->size(); mtp_model_id++) { | ||||
EngineInitParams* mtp_engine_init_params = propose_maga_init_params_->mtp_model_params_->at(mtp_model_id).get(); | ||||
size_t mtp_layer_num = mtp_engine_init_params->gpt_init_parameter.num_layers_; | ||||
size_t mtp_model_id_val = mtp_engine_init_params->model_id; | ||||
|
||||
for (size_t layer_id = 0; layer_id < mtp_layer_num; layer_id++) { | ||||
auto block_num = request->decode_block_ids_size(); | ||||
for (int i = 0; i < block_num; i++) { | ||||
auto decode_block_key = makeCacheKey(mtp_model_id_val, std::to_string(request->decode_block_ids(i)), layer_id); | ||||
auto prefill_block_key = makeCacheKey(mtp_model_id_val, std::to_string(request->prefill_block_ids(i)), layer_id); | ||||
|
||||
store_request->buffer_pairs["k_" + prefill_block_key] = "k_" + decode_block_key; | ||||
if (!engine_->resourceContext().cache_manager->cacheConfig().use_mla) { | ||||
store_request->buffer_pairs["v_" + prefill_block_key] = "v_" + decode_block_key; | ||||
} | ||||
} | ||||
} | ||||
} | ||||
RTP_LLM_LOG_DEBUG("request [%s] added MTP cache mappings for %d models", | ||||
request->request_key().c_str(), | ||||
propose_maga_init_params_->mtp_model_params_->size()); | ||||
} | ||||
|
||||
auto collector = std::make_shared<CacheStoreRemoteStoreMetricsCollector>(metrics_reporter_, | ||||
store_request->buffer_pairs.size()); | ||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Creating a temporary vector and copying data is inefficient. If setProposeToken accepts a RepeatedField or can be modified to accept iterators, pass the protobuf data directly to avoid the copy.
Copilot uses AI. Check for mistakes.