Skip to content

Commit 2b6b7d5

Browse files
stateful gpu: use full chat history for each prefill (#668)
1 parent 921da08 commit 2b6b7d5

File tree

2 files changed

+26
-14
lines changed

2 files changed

+26
-14
lines changed

onnxruntime/core/providers/openvino/ov_interface.cc

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -419,17 +419,23 @@ void OVInferRequest::QueryStatus() {
419419
<< " ";
420420
}
421421

422-
void StatefulOVInferRequest::_pre_infer() {
422+
StatefulOVInferRequest::StatefulOVInferRequest(ov::InferRequest infer_request, std::string d)
423+
: OVInferRequest(std::move(infer_request)), device(d) {
424+
if ((device.find("NPU") != std::string::npos) || (device.find("GPU") != std::string::npos)) {
425+
prefill_use_full_chat_history = true;
426+
}
427+
}
428+
429+
void StatefulOVInferRequest::PreProcessInferRequest() {
423430
// Since we can't seem to set at ORT GenAI layer right now, we just set it here
424431
// as a workaround.
425432
// TODO: Fix this.
426433
ov::Tensor beam_idx = ov::Tensor(ov::element::i32, {1});
427434
std::fill_n(beam_idx.data<int32_t>(), 1, 0);
428435
ovInfReq.set_tensor("beam_idx", beam_idx);
429436

430-
// For NPU, we need to cache input_ids and position_ids for
431-
// chat-mode support.
432-
if (device.find("NPU") != std::string::npos) {
437+
// If 'prefill full chat history' mode is enabled, we need to cache input_ids and position_ids.
438+
if (prefill_use_full_chat_history) {
433439
auto input_ids_tensor = ovInfReq.get_tensor("input_ids");
434440

435441
// add input_ids to our cache
@@ -454,6 +460,9 @@ void StatefulOVInferRequest::_pre_infer() {
454460
// if the input_ids size doesn't equal cached size of the input_ids
455461
// then it means that we're running 2nd (or later) prompt.
456462
if (input_ids_tensor.get_shape()[1] != cached_input_ids.size()) {
463+
// Clear the internal KVCache state (note: this is a no-op for NPU)
464+
ovInfReq.reset_state();
465+
457466
// set a new input_ids tensor with the content of our cached input_ids
458467
{
459468
auto new_shape = input_ids_tensor.get_shape();
@@ -480,19 +489,22 @@ void StatefulOVInferRequest::_pre_infer() {
480489
}
481490

482491
void StatefulOVInferRequest::StartAsync() {
483-
_pre_infer();
492+
PreProcessInferRequest();
484493
OVInferRequest::StartAsync();
485494
}
486495

487496
void StatefulOVInferRequest::Infer() {
488-
_pre_infer();
497+
PreProcessInferRequest();
489498
OVInferRequest::Infer();
490499
}
491500

492501
void StatefulOVInferRequest::RewindKVCache(size_t index) {
493-
if (device == "NPU") {
494-
std::cout << "RewindKVCache on NPU: Trimming cached input_ids / position_ids to length "
495-
<< index << std::endl;
502+
LOGS_DEFAULT(INFO) << log_tag << "RewindKVCache: Rewinding OpenVINO-internal KVCache state to index=" << index << std::endl;
503+
504+
if (prefill_use_full_chat_history) {
505+
// Clear the internal KVCache state (note: this is a no-op for NPU)
506+
ovInfReq.reset_state();
507+
496508
if (cached_input_ids.size() > index) {
497509
cached_input_ids.resize(index);
498510
}
@@ -501,8 +513,6 @@ void StatefulOVInferRequest::RewindKVCache(size_t index) {
501513
cached_position_ids.resize(index);
502514
}
503515
} else {
504-
std::cout << "OVInferRequest::RewindKVCache: Trimming internal states to length = "
505-
<< index << std::endl;
506516
if (index == 0) {
507517
// in this case, since we're trimming *all* of the KVCache, just reset the state.
508518
ovInfReq.reset_state();

onnxruntime/core/providers/openvino/ov_interface.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -142,17 +142,19 @@ class OVInferRequest {
142142

143143
class StatefulOVInferRequest : public OVInferRequest {
144144
public:
145-
explicit StatefulOVInferRequest(ov::InferRequest obj, std::string d) : OVInferRequest(std::move(obj)), device(d) {}
145+
explicit StatefulOVInferRequest(ov::InferRequest infer_request, std::string device);
146146

147147
void StartAsync() override;
148148
void Infer() override;
149149
void RewindKVCache(size_t index) override;
150150

151151
private:
152-
void _pre_infer();
152+
void PreProcessInferRequest();
153153
std::string device;
154154

155-
// For NPU, we need to cache input_ids & position_ids to support chat-mode.
155+
// If prefill_use_full_chat_history is true, cache input_ids & position_ids,
156+
// and ensure that the full chat history is passed to each prefill.
157+
bool prefill_use_full_chat_history = false;
156158
std::vector<int64_t> cached_input_ids;
157159
std::vector<int64_t> cached_position_ids;
158160
};

0 commit comments

Comments (0)