6 changes: 6 additions & 0 deletions onnxruntime/core/providers/openvino/backend_manager.cc
@@ -781,5 +781,11 @@
}
}

void BackendManager::ReorderKVCache(const std::vector<int32_t>& src_indices, const std::vector<int32_t>& dst_indices) {

[cpplint @ line 784] Add #include <vector> for vector<> [build/include_what_you_use] [4]
if (concrete_backend_) {
concrete_backend_->ReorderKVCache(src_indices, dst_indices);
}
}

} // namespace openvino_ep
} // namespace onnxruntime
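
Note: a hedged sketch of what the new entry point is for. The scenario below is assumed for illustration (speculative decoding is only named in a comment later in this PR); the index values and variable names are not from the source.

```cpp
#include <cstdint>
#include <vector>

// Illustrative speculative-decoding scenario (assumed; not part of this PR):
// the draft model speculated 4 tokens at sequence positions 10..13 and the
// target model accepted those at positions 10, 12 and 13. Compacting survivors:
const std::vector<int32_t> src_indices = {12, 13};  // current rows of the accepted tokens
const std::vector<int32_t> dst_indices = {11, 12};  // rows they occupy after compaction
// BackendManager::ReorderKVCache(src_indices, dst_indices) then asks each idle
// infer request to copy cache[..., 12, :] -> cache[..., 11, :] and
// cache[..., 13, :] -> cache[..., 12, :] along the sequence axis of every KV input.
```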
1 change: 1 addition & 0 deletions onnxruntime/core/providers/openvino/backend_manager.h
@@ -31,6 +31,7 @@ class BackendManager {
void TryExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphViewer& subgraph, bool include_embed_data);
ov::CompiledModel GetOVCompiledModel();
void RewindKVCache(size_t index);
void ReorderKVCache(const std::vector<int32_t>& src_indices, const std::vector<int32_t>& dst_indices);

private:
std::unique_ptr<ONNX_NAMESPACE::ModelProto> GetModelProtoFromFusedNode(
6 changes: 6 additions & 0 deletions onnxruntime/core/providers/openvino/backends/basic_backend.cc
@@ -315,6 +315,12 @@
});
}

void BasicBackend::ReorderKVCache(const std::vector<int32_t>& src_indices, const std::vector<int32_t>& dst_indices) {

[cpplint @ line 318] Add #include <vector> for vector<> [build/include_what_you_use] [4]
infer_req_pool_->forEachIdleRequest([&](OVInferRequestPtr& infer_request) {
infer_request->ReorderKVCache(src_indices, dst_indices);
});
}

void BasicBackend::Infer(OrtKernelContext* ctx) const {
Ort::KernelContext context(ctx);

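
Note: only idle requests are touched here; `forEachIdleRequest` mirrors the existing RewindKVCache path above it. A self-contained sketch of that fan-out pattern under assumed pool internals (`MockRequestPool`, `idle_`, `mutex_` are hypothetical names; the real `infer_req_pool_` is EP-internal):

```cpp
#include <functional>
#include <memory>
#include <mutex>
#include <vector>

struct MockRequest { /* stands in for OVInferRequest */ };
using RequestPtr = std::shared_ptr<MockRequest>;

class MockRequestPool {
 public:
  // Apply fn to every request that is not currently running an inference;
  // busy requests are left alone until they return to the idle set.
  void forEachIdleRequest(const std::function<void(RequestPtr&)>& fn) {
    std::lock_guard<std::mutex> guard(mutex_);
    for (auto& request : idle_) fn(request);
  }

 private:
  std::mutex mutex_;
  std::vector<RequestPtr> idle_;
};
```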
1 change: 1 addition & 0 deletions onnxruntime/core/providers/openvino/backends/basic_backend.h
@@ -151,6 +151,7 @@ class BasicBackend : public IBackend {
return exe_network_.Get();
}
void RewindKVCache(size_t index) override;
void ReorderKVCache(const std::vector<int32_t>& src_indices, const std::vector<int32_t>& dst_indices) override;

private:
bool ValidateSubgraph(std::map<std::string, std::shared_ptr<ov::Node>>& const_outputs_map);
1 change: 1 addition & 0 deletions onnxruntime/core/providers/openvino/ibackend.h
@@ -18,6 +18,7 @@
virtual ov::CompiledModel GetOVCompiledModel() = 0;
virtual ~IBackend() = default;
virtual void RewindKVCache(size_t index) {}
virtual void ReorderKVCache(const std::vector<int32_t>& src_indices, const std::vector<int32_t>& dst_indices) {}

[cpplint @ line 21] Add #include <vector> for vector<> [build/include_what_you_use] [4]
};
using ptr_stream_t = std::unique_ptr<ModelBlobWrapper>;
class BackendFactory {
59 changes: 59 additions & 0 deletions onnxruntime/core/providers/openvino/openvino_execution_provider.cc
@@ -5,6 +5,7 @@
#include <string>
#include <memory>
#include <vector>
#include <cerrno>
#include "core/providers/shared_library/provider_api.h"
#include "core/providers/openvino/openvino_execution_provider.h"
#include "core/providers/openvino/contexts.h"
@@ -286,6 +287,64 @@
LOGS_DEFAULT(WARNING) << "kvcache_rewind index is < 0:\t" << index;
}
}
} else if (key == "kvcache_reorder") {
// Convert kvcache_reorder value format "1,2,3;4,5,6" into two vectors
// src_indices = [1,2,3], dst_indices = [4,5,6]
size_t delimiter_pos = value.find(';');
if (delimiter_pos == std::string::npos) {
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
"kvcache_reorder value format is incorrect, expected format is 'x1,x2,x3;y1,y2,y3' where x and y are comma-separated int64_t lists");
}

std::string_view src_string(value.begin(), value.begin() + delimiter_pos);
std::string_view dst_string(value.begin() + delimiter_pos + 1, value.end());

constexpr auto parse_indices = [](std::string_view input, const std::string& index_type) -> std::variant<Status, std::vector<int32_t>> {
std::vector<int32_t> indices;
while (!input.empty()) {
const auto delimiter_pos = input.find(',');
const auto part = input.substr(0, delimiter_pos);
errno = 0;
char* parse_end = nullptr;
// strtol already skips leading whitespace; it stops at the first non-digit
// (e.g. the next ','), so passing the unterminated string_view data is safe
const auto index = std::strtol(part.data(), &parse_end, 10);
if (parse_end == part.data()) {
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
"Failed to parse kvcache_reorder " + index_type + ": " + std::string(part));
}
if (index < 0) {
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
"kvcache_reorder " + index_type + " cannot be negative: " + std::string(part));
}
if (index > std::numeric_limits<int32_t>::max() || errno == ERANGE) {

[cpplint @ line 319] Add #include <limits> for numeric_limits<> [build/include_what_you_use] [4]
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
"kvcache_reorder " + index_type + " exceeds INT32_MAX: " + std::string(part));
}
indices.push_back(static_cast<int32_t>(index));
if (delimiter_pos != std::string_view::npos) {
// trailing characters after the number are ignored; tighten the check here if stricter validation is needed
input.remove_prefix(part.size() + 1);
} else {
break;
}
}
return indices;
};

const auto src_indices = parse_indices(src_string, "src_index");
if (std::holds_alternative<Status>(src_indices)) {
return std::get<Status>(src_indices);
}

const auto dst_indices = parse_indices(dst_string, "dst_index");
if (std::holds_alternative<Status>(dst_indices)) {
return std::get<Status>(dst_indices);
}

// Trigger KVCache Reorder for target Backend with vector arguments
for (auto& backend : backend_managers_) {
backend.ReorderKVCache(std::get<std::vector<int32_t>>(src_indices), std::get<std::vector<int32_t>>(dst_indices));
}
} else {
// Handle unknown options
LOGS_DEFAULT(WARNING) << "Unknown key/value pair - ignoring " << key << "/" << value;
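
Note: these key/value pairs reach the EP through the session's dynamic-options entry point, the same route the existing `kvcache_rewind` option takes. A hedged usage sketch, assuming the public `Ort::Session::SetEpDynamicOptions` API; the index values are illustrative:

```cpp
#include <onnxruntime_cxx_api.h>

// Ask the OpenVINO EP to copy KV rows 8,9,10 over rows 5,6,7
// ("src1,src2,...;dst1,dst2,..." -- the format parsed above).
void RequestKVCacheReorder(Ort::Session& session) {
  const char* keys[] = {"kvcache_reorder"};
  const char* values[] = {"8,9,10;5,6,7"};
  session.SetEpDynamicOptions(keys, values, 1);
}
```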
63 changes: 58 additions & 5 deletions onnxruntime/core/providers/openvino/ov_interface.cc
@@ -109,9 +109,13 @@

bool model_status = IsStateful(model);
LOGS_DEFAULT(INFO) << log_tag << "Model IsStateful() Status:\t" << (model_status ? "True" : "False");
// Flag to add Gather+ScatterElementsUpdate subgraph to reorder KV cache for LLM speculative decoding
bool should_add_kvcache_reorder = false;
if (!model_status) {
LOGS_DEFAULT(INFO) << log_tag << "Converting from Stateless OV Model to Stateful OV Model" << std::endl;
PatchStatefulDecoder(model);
// TODO: extend to the NPU device once OpenVINO NPU has the corresponding optimization
should_add_kvcache_reorder = hw_target.find("GPU") != std::string::npos;
PatchStatefulDecoder(model, should_add_kvcache_reorder);
}

if (onnxruntime::openvino_ep::backend_utils::IsDebugEnabled()) {
@@ -152,7 +156,7 @@

LOGS_DEFAULT(INFO) << log_tag << "Compiling OV Model using Stateful Transformation flow";
compiled_model = OVCore::Get()->core.compile_model(model, hw_target, config);
-OVExeNetwork exe(compiled_model, hw_target, true);
OVExeNetwork exe(compiled_model, hw_target, true, should_add_kvcache_reorder);
return exe;
}

@@ -332,7 +336,7 @@
auto infReq = compiled_model_obj.create_infer_request();
std::shared_ptr<OVInferRequest> ovInfReq;
if (is_stateful_causallm) {
-ovInfReq = std::make_shared<StatefulOVInferRequest>(std::move(infReq), target_device);
ovInfReq = std::make_shared<StatefulOVInferRequest>(std::move(infReq), target_device, is_kvcache_reorder_added);
} else {
ovInfReq = std::make_shared<OVInferRequest>(std::move(infReq));
}
@@ -377,8 +381,8 @@
"In Error Couldn't start Inference");
}

-StatefulOVInferRequest::StatefulOVInferRequest(ov::InferRequest infer_request, std::string device)
-: OVInferRequest(std::move(infer_request)), target_device(device) {
StatefulOVInferRequest::StatefulOVInferRequest(ov::InferRequest infer_request, std::string device, bool kvcache_reorder_added)
: OVInferRequest(std::move(infer_request)), target_device(device), is_kvcache_reorder_added(kvcache_reorder_added) {

[cpplint @ line 385] Add #include <utility> for move [build/include_what_you_use] [4]
bool gpu_or_npu = ((device.find("NPU") != std::string::npos) || (device.find("GPU") != std::string::npos));

_npu_logits_slice_required = IsNPULogitsSliceRequired();
@@ -469,6 +473,32 @@
// TODO(ankit): Address this issue and implement the fix at the appropriate layer.
FillTensor("beam_idx", ov::element::i32, {1}, 0);

if (is_kvcache_reorder_added) {
ov::Shape dst_idx_shape = ovInfReq.get_tensor("dst_idx").get_shape();
const auto kv_num_heads = dst_idx_shape[1];
const auto kv_head_size = dst_idx_shape[3];
if (!kv_src_indices.empty()) {
ov::Tensor src_idx_tensor = ov::Tensor(ov::element::i32, {kv_src_indices.size()});
const auto src_idx_ptr = src_idx_tensor.data<int32_t>();
for (size_t i = 0; i < kv_src_indices.size(); ++i) {
src_idx_ptr[i] = static_cast<int32_t>(kv_src_indices[i]);
}
ovInfReq.set_tensor("src_idx", src_idx_tensor);

ov::Tensor dst_idx_tensor = ov::Tensor(ov::element::i32, {1, kv_num_heads, kv_dst_indices.size(), kv_head_size});
const auto dst_idx_ptr = dst_idx_tensor.data<int32_t>();
for (size_t i = 0; i < kv_num_heads; ++i) {
for (size_t j = 0; j < kv_dst_indices.size(); ++j) {
std::fill_n(dst_idx_ptr + (i * kv_dst_indices.size() + j) * kv_head_size, kv_head_size, kv_dst_indices[j]);
}
}
ovInfReq.set_tensor("dst_idx", dst_idx_tensor);
} else {
FillTensor("src_idx", ov::element::i32, {0}, 0);
FillTensor("dst_idx", ov::element::i32, {1, kv_num_heads, 0, kv_head_size}, 0);
}
}

// If 'prefill use full chat history' mode is enabled, we need to cache input_ids and position_ids.
if (prefill_use_full_chat_history) {
auto input_ids_tensor = ovInfReq.get_tensor("input_ids");
Expand Down Expand Up @@ -505,6 +535,29 @@
void StatefulOVInferRequest::Infer() {
PreProcessInferRequest();
OVInferRequest::Infer();
PostProcessInferRequest();
}

void StatefulOVInferRequest::PostProcessInferRequest() {
if (is_kvcache_reorder_added) {
kv_src_indices.clear();
kv_dst_indices.clear();
}
}

void StatefulOVInferRequest::ReorderKVCache(const std::vector<int32_t>& src_indices, const std::vector<int32_t>& dst_indices) {

[cpplint @ line 548] Add #include <vector> for vector<> [build/include_what_you_use] [4]
// Validate input parameters
if (src_indices.size() != dst_indices.size()) {
ORT_THROW(log_tag + "ReorderKVCache: src_indices and dst_indices must have the same size. "
"Got src_indices.size()=" + std::to_string(src_indices.size()) +
", dst_indices.size()=" + std::to_string(dst_indices.size()));
}

LOGS_DEFAULT(INFO) << log_tag << "ReorderKVCache: Reordering OpenVINO-internal KVCache state with "
<< src_indices.size() << " index pairs";

kv_src_indices = src_indices;
kv_dst_indices = dst_indices;
}

void StatefulOVInferRequest::RewindKVCache(size_t index) {
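
Note on the `dst_idx` construction in PreProcessInferRequest: ScatterElementsUpdate needs a full per-element index tensor on the update axis, so every element of row j must carry `dst_indices[j]`. A small worked sketch of the same fill with plain vectors (H = kv_num_heads, D = kv_head_size, as read from the dst_idx shape in the code above; names are illustrative):

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// For dst = {5, 7}, H = 2, D = 3 the flattened result is
//   [5,5,5, 7,7,7,   5,5,5, 7,7,7]   // logical shape {1, H, dst.size(), D}
std::vector<int32_t> BuildDstIdx(const std::vector<int32_t>& dst, size_t H, size_t D) {
  std::vector<int32_t> out(H * dst.size() * D);
  for (size_t i = 0; i < H; ++i)
    for (size_t j = 0; j < dst.size(); ++j)
      std::fill_n(out.begin() + (i * dst.size() + j) * D, D, dst[j]);
  return out;
}
```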
19 changes: 14 additions & 5 deletions onnxruntime/core/providers/openvino/ov_interface.h
@@ -91,10 +91,11 @@ class OVExeNetwork {
ov::CompiledModel compiled_model_obj;
std::string target_device;
bool is_stateful_causallm;
bool is_kvcache_reorder_added = false;

public:
-explicit OVExeNetwork(ov::CompiledModel compiled_model, std::string device, bool stateful_causallm = false)
-: compiled_model_obj(std::move(compiled_model)), target_device(std::move(device)), is_stateful_causallm(stateful_causallm) {}
explicit OVExeNetwork(ov::CompiledModel compiled_model, std::string device, bool stateful_causallm = false, bool kvcache_reorder_added = false)
: compiled_model_obj(std::move(compiled_model)), target_device(std::move(device)), is_stateful_causallm(stateful_causallm), is_kvcache_reorder_added(kvcache_reorder_added) {}
OVExeNetwork() : compiled_model_obj(ov::CompiledModel()), is_stateful_causallm(false) {}
ov::CompiledModel& Get() { return compiled_model_obj; }
std::shared_ptr<OVInferRequest> CreateInferRequest();
@@ -136,14 +137,16 @@ class OVInferRequest {
return ovInfReq;
}
virtual void RewindKVCache([[maybe_unused]] size_t index) {}
virtual void ReorderKVCache([[maybe_unused]] const std::vector<int32_t>& src_indices, [[maybe_unused]] const std::vector<int32_t>& dst_indices) {}
};

class StatefulOVInferRequest : public OVInferRequest {
public:
-explicit StatefulOVInferRequest(ov::InferRequest infer_request, std::string device);
explicit StatefulOVInferRequest(ov::InferRequest infer_request, std::string device, bool kvcache_reorder_added = false);

void Infer() override;
void RewindKVCache(size_t index) override;
void ReorderKVCache(const std::vector<int32_t>& src_indices, const std::vector<int32_t>& dst_indices) override;
void FillTensor(const std::string& tensor_name, const ov::element::Type& type,
const std::vector<size_t>& shape, int32_t fill_value);
void CacheTensor(const std::string& tensor_name, std::vector<int64_t>& cache);
@@ -153,13 +156,19 @@ class StatefulOVInferRequest : public OVInferRequest {

private:
void PreProcessInferRequest();
void PostProcessInferRequest();
std::string target_device;

std::vector<int64_t> cached_input_ids;
std::vector<int64_t> cached_position_ids;
std::vector<int32_t> kv_src_indices;
std::vector<int32_t> kv_dst_indices;

// If prefill_use_full_chat_history is true, cache the "input_ids" & "position_ids" tensors,
// and ensure that full chat history is passed for each prefill call.
bool prefill_use_full_chat_history = false;
-std::vector<int64_t> cached_input_ids;
-std::vector<int64_t> cached_position_ids;
// If the kvcache_reorder subgraph was added, src_idx/dst_idx are supplied as extra model inputs
bool is_kvcache_reorder_added = false;

bool IsNPULogitsSliceRequired();
bool _npu_logits_slice_required = false;
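
Note: read together with the ov_interface.cc changes above, the staged indices are one-shot. A sketch of the expected call order (`ReorderKVCache`/`Infer` are from this PR; the driver function itself is hypothetical and assumes the EP's ov_interface.h declarations are visible):

```cpp
#include <cstdint>
#include <memory>
#include <vector>

void DecodeStepWithReorder(const std::shared_ptr<StatefulOVInferRequest>& req,
                           const std::vector<int32_t>& src,
                           const std::vector<int32_t>& dst) {
  req->ReorderKVCache(src, dst);  // stage indices only; no OpenVINO call yet
  req->Infer();                   // PreProcess uploads src_idx/dst_idx, the fused
                                  // subgraph applies the copy, PostProcess clears them
  req->Infer();                   // staged indices now empty: src_idx has zero length,
                                  // so the fused update is a no-op on this step
}
```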
43 changes: 39 additions & 4 deletions onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc
@@ -75,7 +75,8 @@ std::string GetInputOutputName(std::shared_ptr<ov::Model> ov_model,
void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
std::vector<std::string>& not_kv_inputs,
const std::vector<std::string>& key_value_input_names,
-int gather_dim) {
int gather_dim,
const bool should_add_kvcache_reorder) {
if (ModelHasInputOutputNames(ov_model, "beam_idx")) {
throw std::runtime_error("Model already has fused cache");
}
@@ -91,13 +92,31 @@ void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
std::string main_input_name = GetInputOutputName(ov_model, input_name_candidates);

auto input_batch = ov_model->input(main_input_name).get_partial_shape()[0];
auto update_shape = ov_model->input(key_value_input_names[0]).get_partial_shape();

auto beam_idx = std::make_shared<ov::opset13::Parameter>(ov::element::i32, ov::PartialShape({std::move(input_batch)}));
beam_idx->set_friendly_name("beam_idx");
beam_idx->output(0).get_tensor().add_names({"beam_idx"});
ov_model->add_parameters({beam_idx});
not_kv_inputs.push_back(beam_idx->get_friendly_name());

std::shared_ptr<ov::opset13::Parameter> src_idx;
std::shared_ptr<ov::opset13::Parameter> dst_idx;

if (should_add_kvcache_reorder) {
src_idx = std::make_shared<ov::opset13::Parameter>(ov::element::i32, ov::PartialShape({update_shape[2]}));
src_idx->set_friendly_name("src_idx");
src_idx->output(0).get_tensor().add_names({"src_idx"});
ov_model->add_parameters({src_idx});
not_kv_inputs.push_back(src_idx->get_friendly_name());

dst_idx = std::make_shared<ov::opset13::Parameter>(ov::element::i32, update_shape);
dst_idx->set_friendly_name("dst_idx");
dst_idx->output(0).get_tensor().add_names({"dst_idx"});
ov_model->add_parameters({dst_idx});
not_kv_inputs.push_back(dst_idx->get_friendly_name());
}

// Go over all cache parameters and fuse _reorder_cache with indices provided by the new parameter beam_idx
for (const auto& input_name : key_value_input_names) {
auto parameter_output_port = ov_model->input(input_name);
@@ -108,9 +127,25 @@ void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
beam_idx,
ov::opset13::Constant::create(ov::element::i64, {}, {gather_dim}));

std::shared_ptr<ov::Node> output_node;
if (should_add_kvcache_reorder) {
auto updatekv_gather_op =
std::make_shared<ov::opset13::Gather>(gather_op,
src_idx,
ov::opset13::Constant::create(ov::element::i64, {}, {2}));

auto updatekv_op = std::make_shared<ov::opset12::ScatterElementsUpdate>(gather_op,
dst_idx,
updatekv_gather_op,
ov::opset13::Constant::create(ov::element::i64, {}, {2}));
output_node = updatekv_op;
} else {
output_node = gather_op;
}

// Replace the source output for all consumers of the input tensor
for (auto& consumer : consumers) {
-consumer.replace_source_output(gather_op->output(0));
consumer.replace_source_output(output_node->output(0));
}
}

@@ -247,7 +282,7 @@ std::pair<std::vector<std::string>, std::vector<std::string>> ExtractInputKVTens
}

// Updated PatchStatefulDecoder function
void PatchStatefulDecoder(std::shared_ptr<ov::Model> model) {
void PatchStatefulDecoder(std::shared_ptr<ov::Model> model, const bool should_add_kvcache_reorder) {
// Use the dynamic pattern-based extraction logic
auto [key_value_output_names, extracted_patterns] = ExtractKVPatternsFromOutputs(model);
auto [key_value_input_names, not_kv_inputs] = ExtractInputKVTensors(model, extracted_patterns);
Expand All @@ -269,7 +304,7 @@ void PatchStatefulDecoder(std::shared_ptr<ov::Model> model) {
// batch_dim = 1 if config.model_type == "chatglm" and not hasattr(config, "rope_ratio") else 0
auto batch_dim = 0;

-FuseCacheReorder(model, not_kv_inputs, key_value_input_names, batch_dim);
FuseCacheReorder(model, not_kv_inputs, key_value_input_names, batch_dim, should_add_kvcache_reorder);

MakeStateful(model, key_value_input_names, key_value_output_names);
}
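
Note: per KV input, the fused computation is ScatterElementsUpdate(Gather(cache, beam_idx, axis=0), dst_idx, Gather(·, src_idx, axis=2), axis=2). A scalar reference for the axis-2 part on one [S, D] slice, with the beam gather omitted for brevity (a sketch of the semantics, not the transformation itself):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Gather the rows listed in src first, then write them over the rows in dst --
// the same two-phase order as Gather(axis=2) + ScatterElementsUpdate(axis=2),
// which keeps overlapping src/dst index sets well-defined.
void FusedUpdateSlice(std::vector<float>& slice, size_t D,  // slice is [S, D] flattened
                      const std::vector<int32_t>& src,
                      const std::vector<int32_t>& dst) {
  std::vector<float> gathered(src.size() * D);
  for (size_t i = 0; i < src.size(); ++i)
    for (size_t d = 0; d < D; ++d)
      gathered[i * D + d] = slice[static_cast<size_t>(src[i]) * D + d];
  for (size_t j = 0; j < dst.size(); ++j)
    for (size_t d = 0; d < D; ++d)
      slice[static_cast<size_t>(dst[j]) * D + d] = gathered[j * D + d];
}
```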
6 changes: 4 additions & 2 deletions onnxruntime/core/providers/openvino/ov_stateful_patch_utils.h
@@ -13,6 +13,7 @@

#include "openvino/pass/manager.hpp"
#include "openvino/pass/make_stateful.hpp"
#include "openvino/opsets/opset12.hpp"
#include "openvino/opsets/opset13.hpp"

namespace onnxruntime {
@@ -25,13 +26,14 @@ bool ModelHasInputOutputNames(std::shared_ptr<ov::Model> model, const std::strin
void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
std::vector<std::string>& not_kv_inputs,
const std::vector<std::string>& key_value_input_names,
-int gather_dim);
int gather_dim,
const bool should_add_kvcache_reorder = false);

void MakeStateful(std::shared_ptr<ov::Model>& ov_model,
const std::vector<std::string>& key_value_input_names,
const std::vector<std::string>& key_value_output_names);

-void PatchStatefulDecoder(std::shared_ptr<ov::Model> model);
void PatchStatefulDecoder(std::shared_ptr<ov::Model> model, const bool should_add_kvcache_reorder = false);

bool HasOpWithType(const std::shared_ptr<const ov::Model>& function, const std::string& type_name);
