17 changes: 17 additions & 0 deletions rtp_llm/cpp/engine_base/stream/GenerateStream.cc
@@ -770,6 +770,23 @@ bool GenerateStream::waitForRemoteGenerate() {
return need_remote_generate_;
}

bool GenerateStream::waitForPrefillMtpReady() {
std::unique_lock<std::mutex> lock(*output_mutex_);

cv_->wait(lock, [this] {
return prefill_mtp_ready_ || generate_status_->status == StreamState::STOPPED
|| generate_status_->status == StreamState::FINISHED;
});

if (!prefill_mtp_ready_ && generate_status_->status == StreamState::STOPPED) {
RTP_LLM_LOG_WARNING("waitForPrefillMtpReady exits due to stream [%ld] stopped, error: %s",
streamId(),
generate_status_->error_info.ToString().c_str());
}

return prefill_mtp_ready_;
}

std::vector<int> GenerateStream::getLatestTokens(size_t token_num) {
return complete_token_ids_->getLatestTokens(token_num);
}
7 changes: 7 additions & 0 deletions rtp_llm/cpp/engine_base/stream/GenerateStream.h
@@ -323,6 +323,12 @@ class GenerateStream {
need_remote_generate_ = need_remote_generate;
}

void setPrefillMtpReady(bool prefill_mtp_ready) {
// Guard the flag with the mutex that waitForPrefillMtpReady() reads it under,
// then wake the waiter; without this the wait can miss the update.
std::lock_guard<std::mutex> lock(*output_mutex_);
prefill_mtp_ready_ = prefill_mtp_ready;
cv_->notify_all();
}

bool waitForPrefillMtpReady();

std::vector<int> getLatestTokens(size_t token_num);

void incBatchWithPrefillTimes(int32_t times);
@@ -531,6 +537,7 @@ class GenerateStream {

bool last_block_aligned_ = false;
volatile bool need_remote_generate_ = false;
bool prefill_mtp_ready_ = false;

bool gen_timeline_ = false;

15 changes: 15 additions & 0 deletions rtp_llm/cpp/model_rpc/DecodeRpcServerNew.cc
@@ -287,6 +287,21 @@ ErrorInfo DecodeRpcServerNew::writeAppendFirstToken(DecodeGenerateContextNew& de
generate_stream->setReuseLength(generate_stream->seqLength() - 1);
generate_stream->setFallbackPrefixLength(generate_stream->reuseLength());
generate_stream->setSpEditRun(false);
generate_stream->setMtpTokenIndex(generate_stream->seqLength() - 1);
generate_stream->setContainProposeToken(true);

// Set propose tokens from prefill response
if (response.propose_token_ids_size() > 0) {
std::vector<int> propose_tokens;
propose_tokens.assign(response.propose_token_ids().begin(), response.propose_token_ids().end());
generate_stream->setProposeToken(propose_tokens);
RTP_LLM_LOG_DEBUG("request [%s] received %d propose tokens from prefill",
decode_context.request_key.c_str(),
propose_tokens.size());
Comment on lines +295 to +300
Copilot AI (Oct 16, 2025):
Creating a temporary vector and copying data is inefficient. If setProposeToken accepts a RepeatedField or can be modified to accept iterators, pass the protobuf data directly to avoid the copy.

Suggested change:
- std::vector<int> propose_tokens;
- propose_tokens.assign(response.propose_token_ids().begin(), response.propose_token_ids().end());
- generate_stream->setProposeToken(propose_tokens);
- RTP_LLM_LOG_DEBUG("request [%s] received %zu propose tokens from prefill",
-                   decode_context.request_key.c_str(),
-                   propose_tokens.size());
+ generate_stream->setProposeToken(response.propose_token_ids().begin(), response.propose_token_ids().end());
+ RTP_LLM_LOG_DEBUG("request [%s] received %d propose tokens from prefill",
+                   decode_context.request_key.c_str(),
+                   response.propose_token_ids_size());
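If setProposeToken is to take the protobuf range directly, a minimal sketch of an iterator-based overload, assuming stream-side member names that are not part of this diff:

```cpp
// Sketch only: accepts any input-iterator range (e.g. RepeatedField::begin()/end()),
// so callers copy once into stream-owned storage instead of via a std::vector.
// propose_tokens_ and contain_propose_token_ are assumed member names.
template <typename InputIt>
void setProposeToken(InputIt first, InputIt last) {
    propose_tokens_.assign(first, last);
    contain_propose_token_ = true;
}
```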

} else {
RTP_LLM_LOG_WARNING("request [%s] MTP enabled but no propose tokens received from prefill",
decode_context.request_key.c_str());
}
}
generate_stream->resetBeginTime(currentTimeUs());

15 changes: 8 additions & 7 deletions rtp_llm/cpp/model_rpc/PrefillGenerateContextNew.cc
@@ -10,14 +10,15 @@ ErrorInfo PrefillGenerateContextNew::init(const std::shared_ptr<EngineBase>& eng

generate_input = QueryConverter::transQuery(&request->input());
generate_input->generate_config->pd_separation = true;
- generate_input->generate_config->force_disable_sp_run = true;

- // TODO: support MTP
- // if (engine->isMTPEagle()) {
- // generate_input->generate_config->force_disable_sp_run = false;
- // } else {
- // generate_input->generate_config->force_disable_sp_run = true;
- // }
// Configure MTP support based on request
if (engine->isMTPEagle()) {
generate_input->generate_config->force_disable_sp_run = false;
RTP_LLM_LOG_DEBUG("request [%s] MTP enabled, allowing speculative run", request_key.c_str());
} else {
generate_input->generate_config->force_disable_sp_run = true;
RTP_LLM_LOG_DEBUG("request [%s] MTP disabled or not supported", request_key.c_str());
}

stream_ = engine->makeStream(generate_input);
request_timeout_ms = stream_->getTimeoutMs();
60 changes: 47 additions & 13 deletions rtp_llm/cpp/model_rpc/PrefillRpcServerNew.cc
@@ -17,6 +17,7 @@ grpc::Status PrefillRpcServerNew::RemoteGenerateNew(grpc::ServerContext*
RemoteGenerateResponsePBNew* response) {
auto modified_request = const_cast<RemoteGenerateRequestPBNew*>(request);
GenerateInputPB* mutable_input = modified_request->mutable_input();


Copilot AI (Oct 20, 2025):
[nitpick] Unnecessary blank line added without purpose.

Suggested change: remove the added blank line.

// reset request_id in prefill
auto request_id = loading_cache_requests_.fetch_add(1, std::memory_order_relaxed);
@@ -69,12 +70,6 @@ grpc::Status PrefillRpcServerNew::RemoteGenerateNew(grpc::ServerContext*
}
prefill_context.wait_store_cache_done_time_us = currentTimeUs();

- // TODO: notify remote store for hidden state
- // if (engine_->isMTPEagle() &&
- // engine_->getDevice()->getDeviceProperties().tp_rank == 0 &&
- // !request->mtp_hidden_states_key.empty()) {
- //}

RTP_LLM_LOG_DEBUG("request [%s] RemoteGenerateNew success, response is %s",
prefill_context.request_key.c_str(),
response->ShortDebugString().c_str());
@@ -109,11 +104,6 @@ bool PrefillRpcServerNew::validRequest(PrefillGenerateContextNew& prefill_contex
return false;
}

- if (engine_->isMTPEagle() && request->mtp_hidden_states_key().empty()) {
- RTP_LLM_LOG_WARNING("request [%s] mtp_hidden_states_key is empty", prefill_context.request_key.c_str());
- return false;
- }

if (request->use_mla() != maga_init_params_.gpt_init_parameter.use_mla_) {
RTP_LLM_LOG_WARNING("request [%s] request is invalid, mla config not match",
prefill_context.request_key.c_str());
@@ -243,6 +233,7 @@ void PrefillRpcServerNew::constructRemoteLoadRequest(PrefillGenerateContextNew&
ErrorInfo PrefillRpcServerNew::generateFirstToken(PrefillGenerateContextNew& prefill_context) {
auto stream = prefill_context.getStream();
engine_->enqueue(stream);

Copilot AI (Oct 20, 2025):
[nitpick] Unnecessary trailing whitespace added.

Suggested change: remove the trailing whitespace.

while (!stream->finished() || stream->hasOutput()) {
const auto result = stream->nextOutput();
if (!result.ok()) {
@@ -254,15 +245,34 @@ ErrorInfo PrefillRpcServerNew::generateFirstToken(PrefillGenerateContextNew& pre
auto response_output = prefill_context.response->mutable_output();
QueryConverter::transResponse(
response_output, &(result.value()), maga_init_params_.gpt_init_parameter.misc_config.aux_string);
- // should only generate one token
- break;

if (engine_->isMTPEagle()) {
stream->waitForPrefillMtpReady();
break;
Copilot AI (Oct 20, 2025):
The break statement causes the loop to exit early only for MTP streams, potentially skipping output processing for non-finished MTP streams. The loop should continue processing outputs until the stream is finished, regardless of MTP status.

Suggested change:
- break;
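The restructuring the comment implies would drain every output first and only then wait for the draft tokens; a sketch under that reading, not the committed code:

```cpp
// Drain all outputs, whether or not this is an MTP stream.
while (!stream->finished() || stream->hasOutput()) {
    const auto result = stream->nextOutput();
    if (!result.ok()) {
        break;  // error path elided; the folded code above handles it
    }
    QueryConverter::transResponse(prefill_context.response->mutable_output(),
                                  &(result.value()),
                                  maga_init_params_.gpt_init_parameter.misc_config.aux_string);
}
// Only after the stream is done, wait for the MTP draft tokens.
if (engine_->isMTPEagle()) {
    stream->waitForPrefillMtpReady();
}
```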

}
}

if (prefill_context.getStream()->finished()) {
RTP_LLM_LOG_INFO("request [%s] generate first token success and finished", prefill_context.request_key.c_str());
}
auto first_token = prefill_context.getStream()->currentExecuteTokens()[0];
prefill_context.response->set_finished(prefill_context.getStream()->finished());
prefill_context.response->set_first_generate_token_id(first_token);
if (engine_->isMTPEagle()) {
auto stream = prefill_context.getStream();
auto propose_tokens = stream->getProposeToken();
if (!propose_tokens.empty()) {
prefill_context.response->mutable_propose_token_ids()->CopyFrom(
{propose_tokens.begin(), propose_tokens.end()});
Comment on lines +265 to +266
Copilot AI (Oct 16, 2025):
The CopyFrom operation with an initializer list creates unnecessary temporary objects. Use RepeatedField::Add() in a loop or assign directly to avoid the temporary vector creation.
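A sketch of the direct assignment the comment suggests; RepeatedField::Assign takes an iterator range, so no temporary container is built (assuming a protobuf version that provides it):

```cpp
auto* ids = prefill_context.response->mutable_propose_token_ids();
ids->Assign(propose_tokens.begin(), propose_tokens.end());  // range-assign, one pass
// Element-wise alternative if Assign is unavailable:
// ids->Reserve(static_cast<int>(propose_tokens.size()));
// for (int tok : propose_tokens) ids->Add(tok);
```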

RTP_LLM_LOG_DEBUG("request [%s] added %d propose tokens",
prefill_context.request_key.c_str(),
propose_tokens.size());
} else {
RTP_LLM_LOG_WARNING("request [%s] MTP enabled but no propose tokens generated",
prefill_context.request_key.c_str());
}
}

return ErrorInfo::OkStatus();
}

@@ -437,6 +447,30 @@ grpc::Status PrefillRpcServerNew::RemoteStore(grpc::ServerContext* server
}
}
}
// MTP kv cache mappings
if (engine_->isMTPEagle() && propose_maga_init_params_) {
for (size_t mtp_model_id = 0; mtp_model_id < propose_maga_init_params_->mtp_model_params_->size(); mtp_model_id++) {
EngineInitParams* mtp_engine_init_params = propose_maga_init_params_->mtp_model_params_->at(mtp_model_id).get();
size_t mtp_layer_num = mtp_engine_init_params->gpt_init_parameter.num_layers_;
size_t mtp_model_id_val = mtp_engine_init_params->model_id;

for (size_t layer_id = 0; layer_id < mtp_layer_num; layer_id++) {
auto block_num = request->decode_block_ids_size();
for (int i = 0; i < block_num; i++) {
auto decode_block_key = makeCacheKey(mtp_model_id_val, std::to_string(request->decode_block_ids(i)), layer_id);
auto prefill_block_key = makeCacheKey(mtp_model_id_val, std::to_string(request->prefill_block_ids(i)), layer_id);

store_request->buffer_pairs["k_" + prefill_block_key] = "k_" + decode_block_key;
if (!engine_->resourceContext().cache_manager->cacheConfig().use_mla) {
store_request->buffer_pairs["v_" + prefill_block_key] = "v_" + decode_block_key;
}
}
}
}
RTP_LLM_LOG_DEBUG("request [%s] added MTP cache mappings for %d models",
request->request_key().c_str(),
propose_maga_init_params_->mtp_model_params_->size());
}

auto collector = std::make_shared<CacheStoreRemoteStoreMetricsCollector>(metrics_reporter_,
store_request->buffer_pairs.size());
2 changes: 1 addition & 1 deletion rtp_llm/cpp/model_rpc/proto/model_rpc_service.proto
@@ -260,7 +260,6 @@ message RemoteGenerateRequestPBNew {
repeated string addrs = 4;
repeated int32 block_ids = 5;
int64 reuse_block_size = 6;
- string mtp_hidden_states_key = 7;
bool use_mla = 8;
int32 layer_num = 9;
int64 deadline_us = 10;
@@ -272,6 +271,7 @@ message RemoteGenerateResponsePBNew {
int32 first_generate_token_id = 3;
int64 first_token_rt_us = 4;
bool finished = 5;
repeated int32 propose_token_ids = 6;
}

message RemoteStorePartition {
8 changes: 8 additions & 0 deletions rtp_llm/cpp/normal_engine/NormalEngine.cc
@@ -324,6 +324,14 @@ absl::Status NormalEngine::step() {
}
int64_t step_begin_time_us = autil::TimeUtility::currentTimeInMicroSeconds();
absl::Status status = executor_->process(streams);

for (auto& stream : streams) {
// set need_remote_generate for pd-sep stream
if (!stream->finished() && stream->queryPdSep()) {
RTP_LLM_LOG_DEBUG("stream [%ld] set need_remote_generate", stream->streamId());
stream->setNeedRemoteGenerate(true);
}
}

if (nullptr != profiler_) {
profiler_step_--;
4 changes: 0 additions & 4 deletions rtp_llm/cpp/normal_engine/NormalGenerateStream.cc
@@ -202,10 +202,6 @@ void NormalGenerateStream::updateOutput(const StreamUpdateInfo& update_info) {
queryPdSep(),
isStreaming(),
update_info.update_remote_generate);
- if (!finished_ && queryPdSep() && update_info.update_remote_generate) {
- RTP_LLM_LOG_DEBUG("stream [%ld] set need_remote_generate", streamId());
- setNeedRemoteGenerateWithoutLock(true);
- }

bool pd_sep_first_token = queryPdSep();
bool need_update = pd_sep_first_token || isStreaming() || finished_;
1 change: 1 addition & 0 deletions rtp_llm/cpp/speculative_engine/SpeculativeEngine.cc
@@ -570,6 +570,7 @@ absl::Status SpeculativeEngine::prefillMtpStep(std::list<GenerateStreamPtr>& str
RTP_LLM_LOG_DEBUG("stream [%ld] set setNeedRemoteGenerate", stream->streamId());
stream->setNeedRemoteGenerate(true);
}
stream->setPrefillMtpReady(true);
}
}
