Skip to content

Commit 761d911

Browse files
committed
feat: load_config will now accept MAX_PROMPT_LEN & MIN_RESPONSE_LEN with enable_causallm=True (#660)
1 parent a7c40d0 commit 761d911

File tree

3 files changed

+47
-15
lines changed

3 files changed

+47
-15
lines changed

onnxruntime/core/providers/openvino/backends/basic_backend.cc

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "core/providers/openvino/backends/basic_backend.h"
1616
#include "core/providers/openvino/onnx_ctx_model_helper.h"
1717
#include "core/providers/openvino/backend_manager.h"
18+
#include "core/providers/openvino/ov_stateful_patch_utils.h"
1819

1920
namespace onnxruntime {
2021

@@ -200,6 +201,15 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) {
200201
if (!session_context_.load_config.empty()) {
201202
const std::map<std::string, ov::AnyMap>& target_config = session_context_.load_config;
202203

204+
if ((session_context_.device_type.find("NPU") != std::string::npos) && session_context_.enable_causallm) {
205+
if (target_config.find("NPU") != target_config.end()) {
206+
auto npu_genai_config = target_config.at("NPU");
207+
CausalLMConfig().ApplyConfig(npu_genai_config, device_config);
208+
} else {
209+
LOGS_DEFAULT(WARNING) << "ORT GenAI CausalLMConfig Configuration not found.";
210+
}
211+
}
212+
203213
if (session_context_.device_type.find("NPU") != std::string::npos) {
204214
auto npuw_config = target_config.at("NPU");
205215

@@ -265,7 +275,8 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) {
265275
auto set_target_properties = [&](const std::string& device, const ov::AnyMap& config_options,
266276
const std::vector<ov::PropertyName>& supported_properties) {
267277
for (const auto& [key, value] : config_options) {
268-
if (key.find("NPUW") != std::string::npos) {
278+
if ((key.find("NPUW") != std::string::npos) ||
279+
((device_config.find(key) != device_config.end()) && session_context_.enable_causallm)) {
269280
continue;
270281
}
271282
if (is_supported_and_mutable(key, supported_properties)) {

onnxruntime/core/providers/openvino/ov_interface.cc

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@
77
#include "core/session/onnxruntime_cxx_api.h"
88
#include "core/providers/shared_library/provider_api.h"
99
#include "core/providers/openvino/backend_utils.h"
10-
11-
// for make stateful utility function(s)
10+
#include "core/providers/openvino/backends/basic_backend.h"
1211
#include "core/providers/openvino/ov_stateful_patch_utils.h"
1312

1413
using Exception = ov::Exception;
@@ -97,9 +96,9 @@ OVExeNetwork OVCore::StatefulCompileModel(std::shared_ptr<OVNetwork>& model,
9796
}
9897

9998
LOGS_DEFAULT(INFO) << log_tag << "Converting from Stateless OV Model to Stateful OV Model" << std::endl;
100-
bool status = IsStateful(model);
101-
std::cout << "IsStateful Status:\t" << status << std::endl;
102-
if (!status) {
99+
bool model_status = IsStateful(model);
100+
LOGS_DEFAULT(INFO) << log_tag << "Model IsStateful() Status:\t" << (model_status ? "True" : "False");
101+
if (!model_status) {
103102
PatchStatefulDecoder(model);
104103
}
105104

@@ -109,17 +108,25 @@ OVExeNetwork OVCore::StatefulCompileModel(std::shared_ptr<OVNetwork>& model,
109108
}
110109

111110
auto kv_pos = GetKVAxesPos(model);
112-
if (onnxruntime::openvino_ep::backend_utils::IsDebugEnabled()) {
113-
std::cout << "kv_pos.batch = " << kv_pos.batch << std::endl;
114-
std::cout << "kv_pos.seq_len = " << kv_pos.seq_len << std::endl;
115-
}
116111

117112
if (hw_target.find("NPU") != std::string::npos) {
118113
KVDesc kv_desc;
119-
kv_desc.max_prompt_len = PopIntAndCast(config, "MAX_PROMPT_LEN").value_or(1024u);
120-
kv_desc.min_response_len = PopIntAndCast(config, "MIN_RESPONSE_LEN").value_or(128u);
114+
auto parse_genai_config = [&](const std::string& key, unsigned int default_value) {
115+
return (config.count(key) && !config.at(key).empty() && config.at(key).as<std::string>() != "0") ?
116+
config.at(key).as<unsigned int>() : default_value;
117+
};
118+
119+
kv_desc.max_prompt_len = parse_genai_config("MAX_PROMPT_LEN", CausalLMConfig().max_prompt_len);
120+
kv_desc.min_response_len = parse_genai_config("MIN_RESPONSE_LEN", CausalLMConfig().min_response_len);
121+
122+
// For compilation, MAX_PROMPT_LEN & MIN_RESPONSE_LEN should not be 0
123+
if (kv_desc.max_prompt_len == 0 || kv_desc.min_response_len == 0) {
124+
ORT_THROW(log_tag + "MAX_PROMPT_LEN and MIN_RESPONSE_LEN cannot be 0 or empty");
125+
}
121126

122127
if (onnxruntime::openvino_ep::backend_utils::IsDebugEnabled()) {
128+
std::cout << "kv_pos.batch = " << kv_pos.batch << std::endl;
129+
std::cout << "kv_pos.seq_len = " << kv_pos.seq_len << std::endl;
123130
std::cout << "kv_desc.max_prompt_len:\t" << kv_desc.max_prompt_len << std::endl;
124131
std::cout << "kv_desc.min_response_len:\t" << kv_desc.min_response_len << std::endl;
125132
}
@@ -132,10 +139,8 @@ OVExeNetwork OVCore::StatefulCompileModel(std::shared_ptr<OVNetwork>& model,
132139
ApplySliceBeforeMatmulTransformation(model);
133140
}
134141

135-
std::cout << "Compiling Stateful OV Model ..." << std::endl;
142+
LOGS_DEFAULT(INFO) << log_tag << "Compiling OV Model using Stateful Transformation flow";
136143
compiled_model = OVCore::Get()->core.compile_model(model, hw_target, config);
137-
std::cout << "Stateful OV Model Compilation Complete" << std::endl;
138-
139144
OVExeNetwork exe(compiled_model, hw_target, true);
140145
return exe;
141146
}

onnxruntime/core/providers/openvino/ov_stateful_patch_utils.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,22 @@ struct KVDesc {
5757
uint32_t min_response_len;
5858
};
5959

60+
struct CausalLMConfig {
61+
void ApplyConfig(const ov::AnyMap& external_config, ov::AnyMap& genai_config) {
62+
if (external_config.find("MAX_PROMPT_LEN") != external_config.end()) {
63+
max_prompt_len = external_config.at("MAX_PROMPT_LEN").as<unsigned int>();
64+
}
65+
if (external_config.find("MIN_RESPONSE_LEN") != external_config.end()) {
66+
min_response_len = external_config.at("MIN_RESPONSE_LEN").as<unsigned int>();
67+
}
68+
genai_config["MAX_PROMPT_LEN"] = ov::Any(max_prompt_len);
69+
genai_config["MIN_RESPONSE_LEN"] = ov::Any(min_response_len);
70+
}
71+
72+
unsigned int max_prompt_len = 1024;
73+
unsigned int min_response_len = 128;
74+
};
75+
6076
void UpdateNPUConfig(ov::AnyMap& config, const KVAxesPosition& kv_pos, const KVDesc& kv_desc);
6177

6278
std::optional<ov::Any> PopOptionNew(ov::AnyMap& config, const std::string& option_name);

0 commit comments

Comments
 (0)