7
7
#include " core/session/onnxruntime_cxx_api.h"
8
8
#include " core/providers/shared_library/provider_api.h"
9
9
#include " core/providers/openvino/backend_utils.h"
10
-
11
- // for make stateful utility function(s)
10
+ #include " core/providers/openvino/backends/basic_backend.h"
12
11
#include " core/providers/openvino/ov_stateful_patch_utils.h"
13
12
14
13
using Exception = ov::Exception;
@@ -97,9 +96,9 @@ OVExeNetwork OVCore::StatefulCompileModel(std::shared_ptr<OVNetwork>& model,
97
96
}
98
97
99
98
LOGS_DEFAULT (INFO) << log_tag << " Converting from Stateless OV Model to Stateful OV Model" << std::endl;
100
- bool status = IsStateful (model);
101
- std::cout << " IsStateful Status:\t " << status << std::endl ;
102
- if (!status ) {
99
+ bool model_status = IsStateful (model);
100
+ LOGS_DEFAULT (INFO) << log_tag << " Model IsStateful() Status:\t " << (model_status ? " True " : " False " ) ;
101
+ if (!model_status ) {
103
102
PatchStatefulDecoder (model);
104
103
}
105
104
@@ -109,17 +108,25 @@ OVExeNetwork OVCore::StatefulCompileModel(std::shared_ptr<OVNetwork>& model,
109
108
}
110
109
111
110
auto kv_pos = GetKVAxesPos (model);
112
- if (onnxruntime::openvino_ep::backend_utils::IsDebugEnabled ()) {
113
- std::cout << " kv_pos.batch = " << kv_pos.batch << std::endl;
114
- std::cout << " kv_pos.seq_len = " << kv_pos.seq_len << std::endl;
115
- }
116
111
117
112
if (hw_target.find (" NPU" ) != std::string::npos) {
118
113
KVDesc kv_desc;
119
- kv_desc.max_prompt_len = PopIntAndCast (config, " MAX_PROMPT_LEN" ).value_or (1024u );
120
- kv_desc.min_response_len = PopIntAndCast (config, " MIN_RESPONSE_LEN" ).value_or (128u );
114
+ auto parse_genai_config = [&](const std::string& key, unsigned int default_value) {
115
+ return (config.count (key) && !config.at (key).empty () && config.at (key).as <std::string>() != " 0" ) ?
116
+ config.at (key).as <unsigned int >() : default_value;
117
+ };
118
+
119
+ kv_desc.max_prompt_len = parse_genai_config (" MAX_PROMPT_LEN" , CausalLMConfig ().max_prompt_len );
120
+ kv_desc.min_response_len = parse_genai_config (" MIN_RESPONSE_LEN" , CausalLMConfig ().min_response_len );
121
+
122
+ // For compilation, MAX_PROMPT_LEN & MIN_RESPONSE_LEN should not be 0
123
+ if (kv_desc.max_prompt_len == 0 || kv_desc.min_response_len == 0 ) {
124
+ ORT_THROW (log_tag + " MAX_PROMPT_LEN and MIN_RESPONSE_LEN cannot be 0 or empty" );
125
+ }
121
126
122
127
if (onnxruntime::openvino_ep::backend_utils::IsDebugEnabled ()) {
128
+ std::cout << " kv_pos.batch = " << kv_pos.batch << std::endl;
129
+ std::cout << " kv_pos.seq_len = " << kv_pos.seq_len << std::endl;
123
130
std::cout << " kv_desc.max_prompt_len:\t " << kv_desc.max_prompt_len << std::endl;
124
131
std::cout << " kv_desc.min_response_len:\t " << kv_desc.min_response_len << std::endl;
125
132
}
@@ -132,10 +139,8 @@ OVExeNetwork OVCore::StatefulCompileModel(std::shared_ptr<OVNetwork>& model,
132
139
ApplySliceBeforeMatmulTransformation (model);
133
140
}
134
141
135
- std::cout << " Compiling Stateful OV Model ... " << std::endl ;
142
+ LOGS_DEFAULT (INFO) << log_tag << " Compiling OV Model using Stateful Transformation flow " ;
136
143
compiled_model = OVCore::Get ()->core .compile_model (model, hw_target, config);
137
- std::cout << " Stateful OV Model Compilation Complete" << std::endl;
138
-
139
144
OVExeNetwork exe (compiled_model, hw_target, true );
140
145
return exe;
141
146
}
0 commit comments