Skip to content

Commit

Permalink
Add SetEpDynamicOptions and remove workload_type from run/session opt…
Browse files Browse the repository at this point in the history
…ions (microsoft#22282)

### Description
Add SetEpDynamicOptions and Remove workload_type from run/session
options.

### Motivation and Context
Added SetEpDynamicOptions as a dynamic way of changing EP settings even
in the middle of a Run
Using workload_type run/session options to set Efficient/Default mode
for workloads does not cover all the scenarios and can lead to priority
inversions. Working on a new API to support setting Efficient/Default
mode for workloads.

---------

Co-authored-by: Luis E. Pena <[email protected]>
  • Loading branch information
l31g and Luis E. Pena authored Oct 10, 2024
1 parent 2bef89c commit 1bc546a
Show file tree
Hide file tree
Showing 8 changed files with 77 additions and 6 deletions.
8 changes: 8 additions & 0 deletions include/onnxruntime/core/framework/execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,14 @@ class IExecutionProvider {
return Status::OK();
}

/**
Called when InferenceSession::SetEpDynamicOptions is called
*/
virtual common::Status SetEpDynamicOptions(gsl::span<const char* const> /*keys*/,
gsl::span<const char* const> /*values*/) {
return Status::OK();
}

/**
Indicate whether the graph capturing mode (e.g., cuda graph) is enabled for
the provider.
Expand Down
19 changes: 19 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -4722,6 +4722,25 @@ struct OrtApi {
* \param[in] adapter OrtLoraAdapter instance
*/
ORT_API2_STATUS(RunOptionsAddActiveLoraAdapter, _Inout_ OrtRunOptions* options, _In_ const OrtLoraAdapter* adapter);

/// @}
/// \name OrtEpDynamicOptions
/// @{

/** \brief Set DynamicOptions for EPs (Execution Providers)
*
* Valid options can be found in `include\onnxruntime\core\session\onnxruntime_session_options_config_keys.h`
* Look for `kOrtEpDynamicOptions`
*
* \param[in] session
* \param[in] list of keys represented by null-terminated strings
* \param[in] list of values represented by null-terminated strings
* \param[in] number of key-value pairs
*
* \since Version 1.20
*/
ORT_API2_STATUS(SetEpDynamicOptions, _Inout_ OrtSession* sess, _In_reads_(kv_len) const char* const* keys,
_In_reads_(kv_len) const char* const* values, _In_ size_t kv_len);
};

/*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,3 @@ static const char* const kOrtRunOptionsConfigQnnRpcControlLatency = "qnn.rpc_con
// If the value is set to -1, cuda graph capture/replay is disabled in that run.
// User are not expected to set the value to 0 as it is reserved for internal use.
static const char* const kOrtRunOptionsConfigCudaGraphAnnotation = "gpu_graph_id";

// Specify the type of workload for this run.
// “Default”: OS determines the scheduling priority and processor performance to service this workload. [Default]
// “Efficient”: OS treats this workload is efficiency oriented with low scheduling priority and efficient processor performance.
static const char* const kOrtRunOptionsWorkloadType = "run.workload_type";
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,9 @@ static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas
// If not provided, default is 4.
static const char* const kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel = "session.qdq_matmulnbits_accuracy_level";

// THIS OPTION IS NOT A REGULAR SESSION OPTION SINCE IT CAN BE MODIFIED AT ANY TIME
// Meant to be used with SetEpDynamicOptions
// Specify the type of workload for this session.
// “Default”: OS determines the scheduling priority and processor performance to service this workload. [Default]
// “Efficient”: OS treats this workload is efficiency oriented with low scheduling priority and efficient processor performance.
static const char* const kOrtSessionOptionsWorkloadType = "session.workload_type";
static const char* const kOrtEpDynamicOptionsWorkloadType = "ep.dynamic.workload_type";
18 changes: 18 additions & 0 deletions onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2475,6 +2475,24 @@ struct ThreadPoolSpinningSwitch {
};
} // namespace

Status InferenceSession::SetEpDynamicOptions(gsl::span<const char* const> keys,
gsl::span<const char* const> values) {
Status retval = Status::OK();

if (!is_inited_) {
LOGS(*session_logger_, ERROR) << "Session was not initialized";
return Status(common::ONNXRUNTIME, common::FAIL, "Session not initialized.");
}

// TODO: only call SetEpDynamicOptions for all providers in-use
for (auto& xp : execution_providers_) {
auto status = xp->SetEpDynamicOptions(keys, values);
ORT_CHECK_AND_SET_RETVAL(status);
}

return retval;
}

Status InferenceSession::Run(const RunOptions& run_options,
gsl::span<const std::string> feed_names, gsl::span<const OrtValue> feeds,
gsl::span<const std::string> output_names, std::vector<OrtValue>* p_fetches,
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/core/session/inference_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,9 @@ class InferenceSession {
*/
[[nodiscard]] common::Status Initialize();

[[nodiscard]] common::Status SetEpDynamicOptions(gsl::span<const char* const> keys,
gsl::span<const char* const> values);

[[nodiscard]] common::Status Run(const RunOptions& run_options, gsl::span<const std::string> feed_names,
gsl::span<const OrtValue> feeds, gsl::span<const std::string> output_names,
std::vector<OrtValue>* p_fetches,
Expand Down
24 changes: 24 additions & 0 deletions onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -843,6 +843,28 @@ void CheckAndAdjustInputSpansForLora(const OrtRunOptions& run_options,

} // namespace

ORT_API_STATUS_IMPL(OrtApis::SetEpDynamicOptions, _Inout_ OrtSession* sess,
_In_reads_(kv_len) const char* const* keys,
_In_reads_(kv_len) const char* const* values,
_In_ size_t kv_len) {
API_IMPL_BEGIN
auto session = reinterpret_cast<::onnxruntime::InferenceSession*>(sess);

auto keys_span = gsl::make_span(keys, kv_len);
auto values_span = gsl::make_span(values, kv_len);

Status status;

if (kv_len == 0) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "no imputs were passed");
} else {
status = session->SetEpDynamicOptions(keys_span,
values_span);
}
return ToOrtStatus(status);
API_IMPL_END
}

ORT_API_STATUS_IMPL(OrtApis::Run, _Inout_ OrtSession* sess, _In_opt_ const OrtRunOptions* run_options,
_In_reads_(input_len) const char* const* input_names,
_In_reads_(input_len) const OrtValue* const* input, size_t input_len,
Expand Down Expand Up @@ -2785,6 +2807,8 @@ static constexpr OrtApi ort_api_1_to_20 = {
&OrtApis::CreateLoraAdapterFromArray,
&OrtApis::ReleaseLoraAdapter,
&OrtApis::RunOptionsAddActiveLoraAdapter,

&OrtApis::SetEpDynamicOptions,
};

// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/session/ort_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -531,4 +531,6 @@ ORT_API_STATUS_IMPL(CreateLoraAdapterFromArray, _In_ const void* bytes, size_t n
ORT_API(void, ReleaseLoraAdapter, _Frees_ptr_opt_ OrtLoraAdapter*);
ORT_API_STATUS_IMPL(RunOptionsAddActiveLoraAdapter, _Inout_ OrtRunOptions* options, _In_ const OrtLoraAdapter* adapter);

ORT_API_STATUS_IMPL(SetEpDynamicOptions, _Inout_ OrtSession* sess, _In_reads_(kv_len) const char* const* keys,
_In_reads_(kv_len) const char* const* values, _In_ size_t kv_len);
} // namespace OrtApis

0 comments on commit 1bc546a

Please sign in to comment.