@@ -419,17 +419,23 @@ void OVInferRequest::QueryStatus() {
         << " ";
 }
 
-void StatefulOVInferRequest::_pre_infer() {
+StatefulOVInferRequest::StatefulOVInferRequest(ov::InferRequest infer_request, std::string d)
+    : OVInferRequest(std::move(infer_request)), device(d) {
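+  // When enabled, prefill re-sends the full chat history (cached host-side in
+  // PreProcessInferRequest) on each new prompt instead of only the new tokens.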
+  if ((device.find("NPU") != std::string::npos) || (device.find("GPU") != std::string::npos)) {
+    prefill_use_full_chat_history = true;
+  }
+}
+
+void StatefulOVInferRequest::PreProcessInferRequest() {
   // Since we can't seem to set at ORT GenAI layer right now, we just set it here
   // as a workaround.
   // TODO: Fix this.
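+  // beam_idx is the stateful-model input OpenVINO uses to reorder KV-cache entries
+  // for beam search; a single 0 leaves the batch-1 cache slot untouched.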
   ov::Tensor beam_idx = ov::Tensor(ov::element::i32, {1});
   std::fill_n(beam_idx.data<int32_t>(), 1, 0);
   ovInfReq.set_tensor("beam_idx", beam_idx);
 
-  // For NPU, we need to cache input_ids and position_ids for
-  // chat-mode support.
-  if (device.find("NPU") != std::string::npos) {
+  // If 'prefill full chat history' mode is enabled, we need to cache input_ids and position_ids.
+  if (prefill_use_full_chat_history) {
     auto input_ids_tensor = ovInfReq.get_tensor("input_ids");
 
     // add input_ids to our cache
@@ -454,6 +460,9 @@ void StatefulOVInferRequest::_pre_infer() {
   // if the input_ids size doesn't equal cached size of the input_ids
   // then it means that we're running 2nd (or later) prompt.
   if (input_ids_tensor.get_shape()[1] != cached_input_ids.size()) {
+    // Clear the internal KVCache state (note: this is a no-op for NPU)
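+    // This ensures the replayed full-history prefill starts from an empty device-side cache.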
+    ovInfReq.reset_state();
+
     // set a new input_ids tensor with the content of our cached input_ids
     {
       auto new_shape = input_ids_tensor.get_shape();
@@ -480,19 +489,22 @@ void StatefulOVInferRequest::_pre_infer() {
 }
 
 void StatefulOVInferRequest::StartAsync() {
-  _pre_infer();
+  PreProcessInferRequest();
   OVInferRequest::StartAsync();
 }
 
 void StatefulOVInferRequest::Infer() {
-  _pre_infer();
+  PreProcessInferRequest();
   OVInferRequest::Infer();
 }
 
 void StatefulOVInferRequest::RewindKVCache(size_t index) {
-  if (device == "NPU") {
-    std::cout << "RewindKVCache on NPU: Trimming cached input_ids / position_ids to length "
-              << index << std::endl;
+  LOGS_DEFAULT(INFO) << log_tag << "RewindKVCache: Rewinding OpenVINO-internal KVCache state to index=" << index << std::endl;
+
+  if (prefill_use_full_chat_history) {
+    // Clear the internal KVCache state (note: this is a no-op for NPU)
+    ovInfReq.reset_state();
+
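+    // Trim the host-side caches so the next prefill replays only the first 'index' tokens.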
     if (cached_input_ids.size() > index) {
       cached_input_ids.resize(index);
     }
@@ -501,8 +513,6 @@ void StatefulOVInferRequest::RewindKVCache(size_t index) {
       cached_position_ids.resize(index);
     }
   } else {
-    std::cout << "OVInferRequest::RewindKVCache: Trimming internal states to length = "
-              << index << std::endl;
     if (index == 0) {
       // in this case, since we're trimming *all* of the KVCache, just reset the state.
       ovInfReq.reset_state();