From 040e6fb0c198a4be1109a8b7048c88e3b3b9e3f7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=9D=89=E5=90=9F?=
Date: Thu, 16 Oct 2025 16:47:56 +0800
Subject: [PATCH] feat: decode entrance support mtp

---
 .../cpp/engine_base/stream/GenerateStream.cc  | 17 ++++++
 .../cpp/engine_base/stream/GenerateStream.h   |  9 +++
 rtp_llm/cpp/model_rpc/DecodeRpcServerNew.cc   | 15 +++++
 .../model_rpc/PrefillGenerateContextNew.cc    | 15 ++---
 rtp_llm/cpp/model_rpc/PrefillRpcServerNew.cc  | 60 +++++++++++++++----
 .../model_rpc/proto/model_rpc_service.proto   |  3 ++-
 rtp_llm/cpp/normal_engine/NormalEngine.cc     |  8 +++
 .../cpp/normal_engine/NormalGenerateStream.cc |  4 --
 .../speculative_engine/SpeculativeEngine.cc   |  1 +
 9 files changed, 107 insertions(+), 25 deletions(-)

diff --git a/rtp_llm/cpp/engine_base/stream/GenerateStream.cc b/rtp_llm/cpp/engine_base/stream/GenerateStream.cc
index ee8319047..62ff18988 100644
--- a/rtp_llm/cpp/engine_base/stream/GenerateStream.cc
+++ b/rtp_llm/cpp/engine_base/stream/GenerateStream.cc
@@ -770,6 +770,23 @@ bool GenerateStream::waitForRemoteGenerate() {
     return need_remote_generate_;
 }
 
+bool GenerateStream::waitForPrefillMtpReady() {
+    std::unique_lock lock(*output_mutex_);
+
+    cv_->wait(lock, [this] {
+        return prefill_mtp_ready_ || generate_status_->status == StreamState::STOPPED
+               || generate_status_->status == StreamState::FINISHED;
+    });
+
+    if (!prefill_mtp_ready_ && generate_status_->status == StreamState::STOPPED) {
+        RTP_LLM_LOG_WARNING("waitForPrefillMtpReady exits due to stream [%ld] stopped, error: %s",
+                            streamId(),
+                            generate_status_->error_info.ToString().c_str());
+    }
+
+    return prefill_mtp_ready_;
+}
+
 std::vector<int> GenerateStream::getLatestTokens(size_t token_num) {
     return complete_token_ids_->getLatestTokens(token_num);
 }
diff --git a/rtp_llm/cpp/engine_base/stream/GenerateStream.h b/rtp_llm/cpp/engine_base/stream/GenerateStream.h
index 16bbcb3fa..57fd73df5 100644
--- a/rtp_llm/cpp/engine_base/stream/GenerateStream.h
+++ b/rtp_llm/cpp/engine_base/stream/GenerateStream.h
@@ -323,6 +323,14 @@ class GenerateStream {
         need_remote_generate_ = need_remote_generate;
     }
 
+    void setPrefillMtpReady(bool prefill_mtp_ready) {
+        std::unique_lock lock(*output_mutex_);
+        prefill_mtp_ready_ = prefill_mtp_ready;
+        cv_->notify_all();
+    }
+
+    bool waitForPrefillMtpReady();
+
     std::vector<int> getLatestTokens(size_t token_num);
 
     void incBatchWithPrefillTimes(int32_t times);
@@ -531,6 +539,7 @@ class GenerateStream {
     bool last_block_aligned_ = false;
 
     volatile bool need_remote_generate_ = false;
+    bool prefill_mtp_ready_ = false;
 
     bool gen_timeline_ = false;
 
diff --git a/rtp_llm/cpp/model_rpc/DecodeRpcServerNew.cc b/rtp_llm/cpp/model_rpc/DecodeRpcServerNew.cc
index 395cdcf02..bd4c365a0 100644
--- a/rtp_llm/cpp/model_rpc/DecodeRpcServerNew.cc
+++ b/rtp_llm/cpp/model_rpc/DecodeRpcServerNew.cc
@@ -287,6 +287,21 @@ ErrorInfo DecodeRpcServerNew::writeAppendFirstToken(DecodeGenerateContextNew& de
         generate_stream->setReuseLength(generate_stream->seqLength() - 1);
         generate_stream->setFallbackPrefixLength(generate_stream->reuseLength());
         generate_stream->setSpEditRun(false);
+        generate_stream->setMtpTokenIndex(generate_stream->seqLength() - 1);
+        generate_stream->setContainProposeToken(true);
+
+        // Set propose tokens from prefill response
+        if (response.propose_token_ids_size() > 0) {
+            std::vector<int> propose_tokens;
+            propose_tokens.assign(response.propose_token_ids().begin(), response.propose_token_ids().end());
+            generate_stream->setProposeToken(propose_tokens);
+            RTP_LLM_LOG_DEBUG("request [%s] received %zu propose tokens from prefill",
+                              decode_context.request_key.c_str(),
+                              propose_tokens.size());
+        } else {
+            RTP_LLM_LOG_WARNING("request [%s] MTP enabled but no propose tokens received from prefill",
+                                decode_context.request_key.c_str());
+        }
     }
 
     generate_stream->resetBeginTime(currentTimeUs());
diff --git a/rtp_llm/cpp/model_rpc/PrefillGenerateContextNew.cc b/rtp_llm/cpp/model_rpc/PrefillGenerateContextNew.cc
index 5f2e9ce74..32afeedef 100644
--- a/rtp_llm/cpp/model_rpc/PrefillGenerateContextNew.cc
+++ b/rtp_llm/cpp/model_rpc/PrefillGenerateContextNew.cc
@@ -10,14 +10,15 @@ ErrorInfo PrefillGenerateContextNew::init(const std::shared_ptr<EngineBase>& eng
     generate_input = QueryConverter::transQuery(&request->input());
 
     generate_input->generate_config->pd_separation = true;
 
-    generate_input->generate_config->force_disable_sp_run = true;
-    // TODO: support MTP
-    // if (engine->isMTPEagle()) {
-    //     generate_input->generate_config->force_disable_sp_run = false;
-    // } else {
-    //     generate_input->generate_config->force_disable_sp_run = true;
-    // }
+    // Configure MTP support based on request
+    if (engine->isMTPEagle()) {
+        generate_input->generate_config->force_disable_sp_run = false;
+        RTP_LLM_LOG_DEBUG("request [%s] MTP enabled, allowing speculative run", request_key.c_str());
+    } else {
+        generate_input->generate_config->force_disable_sp_run = true;
+        RTP_LLM_LOG_DEBUG("request [%s] MTP disabled or not supported", request_key.c_str());
+    }
 
     stream_ = engine->makeStream(generate_input);
     request_timeout_ms = stream_->getTimeoutMs();
diff --git a/rtp_llm/cpp/model_rpc/PrefillRpcServerNew.cc b/rtp_llm/cpp/model_rpc/PrefillRpcServerNew.cc
index 5e474ebdd..2da8c1d78 100644
--- a/rtp_llm/cpp/model_rpc/PrefillRpcServerNew.cc
+++ b/rtp_llm/cpp/model_rpc/PrefillRpcServerNew.cc
@@ -17,6 +17,7 @@ grpc::Status PrefillRpcServerNew::RemoteGenerateNew(grpc::ServerContext* server
                                                     RemoteGenerateResponsePBNew* response) {
     auto modified_request = const_cast<RemoteGenerateRequestPBNew*>(request);
     GenerateInputPB* mutable_input = modified_request->mutable_input();
+
     // reset request_id in prefill
     auto request_id = loading_cache_requests_.fetch_add(1, std::memory_order_relaxed);
 
@@ -69,12 +70,6 @@ grpc::Status PrefillRpcServerNew::RemoteGenerateNew(grpc::ServerContext* server
     }
     prefill_context.wait_store_cache_done_time_us = currentTimeUs();
 
-    // TODO: notify remote store for hidden state
-    // if (engine_->isMTPEagle() &&
-    //     engine_->getDevice()->getDeviceProperties().tp_rank == 0 &&
-    //     !request->mtp_hidden_states_key.empty()) {
-    //}
-
     RTP_LLM_LOG_DEBUG("request [%s] RemoteGenerateNew success, response is %s",
                       prefill_context.request_key.c_str(),
                       response->ShortDebugString().c_str());
@@ -109,11 +104,6 @@ bool PrefillRpcServerNew::validRequest(PrefillGenerateContextNew& prefill_contex
         return false;
     }
 
-    if (engine_->isMTPEagle() && request->mtp_hidden_states_key().empty()) {
-        RTP_LLM_LOG_WARNING("request [%s] mtp_hidden_states_key is empty", prefill_context.request_key.c_str());
-        return false;
-    }
-
     if (request->use_mla() != maga_init_params_.gpt_init_parameter.use_mla_) {
         RTP_LLM_LOG_WARNING("request [%s] request is invalid, mla config not match",
                             prefill_context.request_key.c_str());
@@ -243,6 +233,7 @@ void PrefillRpcServerNew::constructRemoteLoadRequest(PrefillGenerateContextNew&
 ErrorInfo PrefillRpcServerNew::generateFirstToken(PrefillGenerateContextNew& prefill_context) {
     auto stream = prefill_context.getStream();
     engine_->enqueue(stream);
+
     while (!stream->finished() || stream->hasOutput()) {
         const auto result = stream->nextOutput();
         if (!result.ok()) {
@@ -254,15 +245,34 @@ ErrorInfo PrefillRpcServerNew::generateFirstToken(PrefillGenerateContextNew& pre
         auto response_output = prefill_context.response->mutable_output();
         QueryConverter::transResponse(
             response_output, &(result.value()), maga_init_params_.gpt_init_parameter.misc_config.aux_string);
-        // should only generate one token
-        break;
+
+        if (engine_->isMTPEagle()) {
+            stream->waitForPrefillMtpReady();
+        }
+        break;
     }
+
     if (prefill_context.getStream()->finished()) {
         RTP_LLM_LOG_INFO("request [%s] generate first token success and finished", prefill_context.request_key.c_str());
     }
     auto first_token = prefill_context.getStream()->currentExecuteTokens()[0];
     prefill_context.response->set_finished(prefill_context.getStream()->finished());
     prefill_context.response->set_first_generate_token_id(first_token);
+
+    if (engine_->isMTPEagle()) {
+        auto stream = prefill_context.getStream();
+        auto propose_tokens = stream->getProposeToken();
+        if (!propose_tokens.empty()) {
+            prefill_context.response->mutable_propose_token_ids()->CopyFrom(
+                {propose_tokens.begin(), propose_tokens.end()});
+            RTP_LLM_LOG_DEBUG("request [%s] added %zu propose tokens",
+                              prefill_context.request_key.c_str(),
+                              propose_tokens.size());
+        } else {
+            RTP_LLM_LOG_WARNING("request [%s] MTP enabled but no propose tokens generated",
+                                prefill_context.request_key.c_str());
+        }
+    }
 
     return ErrorInfo::OkStatus();
 }
@@ -437,6 +447,30 @@ grpc::Status PrefillRpcServerNew::RemoteStore(grpc::ServerContext* server
             }
         }
     }
+    // MTP kv cache mappings
+    if (engine_->isMTPEagle() && propose_maga_init_params_) {
+        for (size_t mtp_model_id = 0; mtp_model_id < propose_maga_init_params_->mtp_model_params_->size(); mtp_model_id++) {
+            EngineInitParams* mtp_engine_init_params = propose_maga_init_params_->mtp_model_params_->at(mtp_model_id).get();
+            size_t mtp_layer_num = mtp_engine_init_params->gpt_init_parameter.num_layers_;
+            size_t mtp_model_id_val = mtp_engine_init_params->model_id;
+
+            for (size_t layer_id = 0; layer_id < mtp_layer_num; layer_id++) {
+                auto block_num = request->decode_block_ids_size();
+                for (int i = 0; i < block_num; i++) {
+                    auto decode_block_key = makeCacheKey(mtp_model_id_val, std::to_string(request->decode_block_ids(i)), layer_id);
+                    auto prefill_block_key = makeCacheKey(mtp_model_id_val, std::to_string(request->prefill_block_ids(i)), layer_id);
+
+                    store_request->buffer_pairs["k_" + prefill_block_key] = "k_" + decode_block_key;
+                    if (!engine_->resourceContext().cache_manager->cacheConfig().use_mla) {
+                        store_request->buffer_pairs["v_" + prefill_block_key] = "v_" + decode_block_key;
+                    }
+                }
+            }
+        }
+        RTP_LLM_LOG_DEBUG("request [%s] added MTP cache mappings for %zu models",
+                          request->request_key().c_str(),
+                          propose_maga_init_params_->mtp_model_params_->size());
+    }
 
     auto collector =
         std::make_shared(metrics_reporter_, store_request->buffer_pairs.size());
diff --git a/rtp_llm/cpp/model_rpc/proto/model_rpc_service.proto b/rtp_llm/cpp/model_rpc/proto/model_rpc_service.proto
index 4db6b364b..1d7a9ce94 100644
--- a/rtp_llm/cpp/model_rpc/proto/model_rpc_service.proto
+++ b/rtp_llm/cpp/model_rpc/proto/model_rpc_service.proto
@@ -260,7 +260,7 @@ message RemoteGenerateRequestPBNew {
     repeated string addrs = 4;
    repeated int32 block_ids = 5;
     int64 reuse_block_size = 6;
-    string mtp_hidden_states_key = 7;
+    reserved 7;
     bool use_mla = 8;
     int32 layer_num = 9;
     int64 deadline_us = 10;
@@ -272,6 +272,7 @@ message RemoteGenerateResponsePBNew {
     int32 first_generate_token_id = 3;
     int64 first_token_rt_us = 4;
     bool finished = 5;
+    repeated int32 propose_token_ids = 6;
 }
 
 message RemoteStorePartition {
diff --git a/rtp_llm/cpp/normal_engine/NormalEngine.cc b/rtp_llm/cpp/normal_engine/NormalEngine.cc
index f97b02ef8..5d0d11cc4 100644
--- a/rtp_llm/cpp/normal_engine/NormalEngine.cc
+++ b/rtp_llm/cpp/normal_engine/NormalEngine.cc
@@ -324,6 +324,14 @@ absl::Status NormalEngine::step() {
     }
     int64_t step_begin_time_us = autil::TimeUtility::currentTimeInMicroSeconds();
     absl::Status status = executor_->process(streams);
+
+    for (auto& stream : streams) {
+        // set need_remote_generate for pd-sep stream
+        if (!stream->finished() && stream->queryPdSep()) {
+            RTP_LLM_LOG_DEBUG("stream [%ld] set need_remote_generate", stream->streamId());
+            stream->setNeedRemoteGenerate(true);
+        }
+    }
 
     if (nullptr != profiler_) {
         profiler_step_--;
diff --git a/rtp_llm/cpp/normal_engine/NormalGenerateStream.cc b/rtp_llm/cpp/normal_engine/NormalGenerateStream.cc
index c9cfb236f..3a4529bbf 100644
--- a/rtp_llm/cpp/normal_engine/NormalGenerateStream.cc
+++ b/rtp_llm/cpp/normal_engine/NormalGenerateStream.cc
@@ -202,10 +202,6 @@ void NormalGenerateStream::updateOutput(const StreamUpdateInfo& update_info) {
                       queryPdSep(),
                       isStreaming(),
                       update_info.update_remote_generate);
-    if (!finished_ && queryPdSep() && update_info.update_remote_generate) {
-        RTP_LLM_LOG_DEBUG("stream [%ld] set need_remote_generate", streamId());
-        setNeedRemoteGenerateWithoutLock(true);
-    }
 
     bool pd_sep_first_token = queryPdSep();
     bool need_update = pd_sep_first_token || isStreaming() || finished_;
diff --git a/rtp_llm/cpp/speculative_engine/SpeculativeEngine.cc b/rtp_llm/cpp/speculative_engine/SpeculativeEngine.cc
index ff2baf2a7..baab42729 100644
--- a/rtp_llm/cpp/speculative_engine/SpeculativeEngine.cc
+++ b/rtp_llm/cpp/speculative_engine/SpeculativeEngine.cc
@@ -570,6 +570,7 @@ absl::Status SpeculativeEngine::prefillMtpStep(std::list<GenerateStreamPtr>& str
             RTP_LLM_LOG_DEBUG("stream [%ld] set setNeedRemoteGenerate", stream->streamId());
             stream->setNeedRemoteGenerate(true);
         }
+        stream->setPrefillMtpReady(true);
     }
 }
 