Compile API: support for OrtModel input and write output to stream #24740

Status: Open. Wants to merge 14 commits into base: main. Showing changes from 12 commits.

80 changes: 70 additions & 10 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -453,6 +453,22 @@ typedef OrtStatus*(ORT_API_CALL* EpSelectionDelegate)(_In_ const OrtEpDevice** e
_Out_ size_t* num_selected,
_In_ void* state);

/** \brief Function that writes a buffer to a stream.
*
* \param stream_state Opaque pointer holding the state for the user's stream.
* \param buffer The buffer to write to the stream.
* \param buffer_num_bytes The size of the buffer in bytes.
* \param num_bytes_written Output parameter that should be set to the number of bytes written to the stream.
*
* \return OrtStatus* Write status. Return nullptr on success.
* Use CreateStatus to provide error info. Use ORT_FAIL as the error code.
* ORT will release the OrtStatus* if not null.
*/
typedef OrtStatus*(ORT_API_CALL* OrtOutStreamWriteFunc)(_In_ void* stream_state,
_In_ const void* buffer,
_In_ size_t buffer_num_bytes,
_Out_ size_t* num_bytes_written);
Contributor:

do we need num_bytes_written? doesn't that complicate things vs. requiring the function implementer to handle all bytes they were given?

Contributor Author:

Good call. Removed.
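
For reference, a callback matching the signature shown in this revision (which still has `num_bytes_written`; per the reply above it is dropped in a later commit) might look like the sketch below. `OrtStatus` and `OrtOutStreamWriteFunc` are local stand-ins so the snippet compiles without the ONNX Runtime headers (the real typedef also carries `ORT_API_CALL`). `VectorSink`, `WriteToVector`, and `DrainThroughCallback` are hypothetical names, and the per-call cap exists only to exercise partial writes.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Stand-ins for the ORT types so this sketch compiles on its own (assumption:
// the real typedef has the same parameter list plus ORT_API_CALL).
struct OrtStatus;
using OrtOutStreamWriteFunc = OrtStatus* (*)(void* stream_state, const void* buffer,
                                             std::size_t buffer_num_bytes,
                                             std::size_t* num_bytes_written);

// Hypothetical user stream: collects bytes, accepting at most max_per_call per write.
struct VectorSink {
  std::vector<char> bytes;
  std::size_t max_per_call = 4;
};

// Callback: appends up to max_per_call bytes and reports how many were taken.
OrtStatus* WriteToVector(void* stream_state, const void* buffer,
                         std::size_t buffer_num_bytes, std::size_t* num_bytes_written) {
  auto* sink = static_cast<VectorSink*>(stream_state);
  const char* src = static_cast<const char*>(buffer);
  const std::size_t n = std::min(buffer_num_bytes, sink->max_per_call);
  sink->bytes.insert(sink->bytes.end(), src, src + n);
  *num_bytes_written = n;
  return nullptr;  // nullptr signals success
}

// Emulates the caller side: keeps calling until every byte has been consumed.
// Returns false if the callback reports an error or makes no progress.
bool DrainThroughCallback(OrtOutStreamWriteFunc write_func, void* state, const std::string& data) {
  std::size_t offset = 0;
  while (offset < data.size()) {
    std::size_t written = 0;
    const std::size_t remaining = data.size() - offset;
    if (write_func(state, data.data() + offset, remaining, &written) != nullptr) return false;
    if (written == 0 || written > remaining) return false;
    offset += written;
  }
  return true;
}
```

If the parameter is removed as discussed, the callback would presumably have to consume the whole buffer on every call, and the retry loop on the caller side would no longer be needed.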


/** \brief Algorithm to use for cuDNN Convolution Op
*/
typedef enum OrtCudnnConvAlgoSearch {
@@ -5831,8 +5847,9 @@ struct OrtCompileApi {

/** \brief Sets the file path to the input ONNX model to compile.
*
* The input model's location (e.g., file path or memory buffer) must be set with either
* ModelCompilationOptions_SetInputModelPath or ModelCompilationOptions_SetInputModelFromBuffer.
* The input model's location (e.g., file path, memory buffer, or OrtModel) must be set with one of the functions that
* begin with ModelCompilationOptions_SetInputModel____, otherwise CompileModel() will return an OrtStatus with error
* code ORT_INVALID_ARGUMENT.
*
* \param[in] model_compile_options The OrtModelCompilationOptions instance.
* \param[in] input_model_path Null terminated string of the path (wchar on Windows, char otherwise).
@@ -5846,8 +5863,9 @@ struct OrtCompileApi {

/** \brief Sets the buffer that stores the bytes of the loaded ONNX model to compile.
*
* The input model's location (e.g., file path or memory buffer) must be set with either
* ModelCompilationOptions_SetInputModelPath or ModelCompilationOptions_SetInputModelFromBuffer.
* The input model's location (e.g., file path, memory buffer, or OrtModel) must be set with one of the functions that
* begin with ModelCompilationOptions_SetInputModel____, otherwise CompileModel() will return an OrtStatus with error
* code ORT_INVALID_ARGUMENT.
*
* \param[in] model_compile_options The OrtModelCompilationOptions instance.
* \param[in] input_model_data Buffer containing the loaded ONNX model bytes.
@@ -5864,10 +5882,10 @@ struct OrtCompileApi {

/** \brief Sets the file path for the output ONNX model generated by CompileModel.
*
* The output model's location (e.g., file path or memory buffer) can be set with either
* ModelCompilationOptions_SetOutputModelPath or ModelCompilationOptions_SetOutputModelBuffer.
* The output model's destination (e.g., file path, memory buffer, or stream) can be set with any of the functions
* that begin with ModelCompilationOptions_SetOutputModel____.
*
* If the output model's location is not set, ONNX Runtime will generate an output file with a path based on
* If the output model's destination is not set, ONNX Runtime will generate an output file with a path based on
* the input model's file path. Examples:
* /Path/my_model.onnx -> /Path/my_model_ctx.onnx
* /Path/my_model -> /Path/my_model_ctx.onnx
@@ -5906,10 +5924,10 @@ struct OrtCompileApi {
*
* The caller passes an OrtAllocator that ONNX Runtime uses to allocate memory for the buffer.
*
* The output model's location (e.g., file path or memory buffer) can be set with either
* ModelCompilationOptions_SetOutputModelPath or ModelCompilationOptions_SetOutputModelBuffer.
* The output model's destination (e.g., file path, memory buffer, or stream) can be set with any of the functions
* that begin with ModelCompilationOptions_SetOutputModel____.
*
* If the output model's location is not set, ONNX Runtime will generate an output file with a path based on
* If the output model's destination is not set, ONNX Runtime will generate an output file with a path based on
* the input model's file path. Examples:
* /Path/my_model.onnx -> /Path/my_model_ctx.onnx
* /Path/my_model -> /Path/my_model_ctx.onnx
@@ -5964,6 +5982,48 @@ struct OrtCompileApi {
* \since Version 1.22.
*/
ORT_API2_STATUS(CompileModel, _In_ const OrtEnv* env, _In_ const OrtModelCompilationOptions* model_options);

/** \brief Sets the input OrtModel instance to compile.
*
* The input model's location (e.g., file path, memory buffer, or OrtModel) must be set with one of the functions that
* begin with ModelCompilationOptions_SetInputModel____, otherwise CompileModel() will return an OrtStatus with error
* code ORT_INVALID_ARGUMENT.
*
* \param[in] model_compile_options The OrtModelCompilationOptions instance.
* \param[in] input_model The OrtModel instance to compile.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.23.
*/
ORT_API2_STATUS(ModelCompilationOptions_SetInputModel, _In_ OrtModelCompilationOptions* model_compile_options,
_In_ const OrtModel* input_model);
Comment on lines +5997 to +5998 (Contributor):

Does this work with both variants of OrtModel? One completely created with the model editor API and one that augments an existing model? I think it's fine to only support the former.


/** \brief Sets an output stream used to write out the output model's serialized ONNX bytes.
*
* The provided write function is called repeatedly until the entire output model has been written out. Each call to
* the write function must write at least one byte, otherwise CompileModel() will return an OrtStatus with error code
* ORT_FAIL.
*
* The output model's destination (e.g., file path, memory buffer, or stream) can be set with any of the functions
* that begin with ModelCompilationOptions_SetOutputModel____.
*
* If the output model's destination is not set, ONNX Runtime will generate an output file with a path based on
* the input model's file path. Examples:
* /Path/my_model.onnx -> /Path/my_model_ctx.onnx
* /Path/my_model -> /Path/my_model_ctx.onnx
*
* \param[in] model_compile_options The OrtModelCompilationOptions instance.
* \param[in] write_stream_func The OrtOutStreamWriteFunc function to call when writing out the model.
* \param[in] stream_state Opaque stream state passed as the first argument to OrtOutStreamWriteFunc. Can be null.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.23.
*/
ORT_API2_STATUS(ModelCompilationOptions_SetOutputModelOutStream,
_In_ OrtModelCompilationOptions* model_compile_options,
_In_ OrtOutStreamWriteFunc write_stream_func, _In_ void* stream_state);
};
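
The default output path rule repeated in the `SetOutputModel____` doc comments can be illustrated with a small helper. This is only a sketch that reproduces the two documented examples; `DefaultCtxModelPath` is a hypothetical name, not a function in ONNX Runtime, and the behavior for other file extensions is not stated in the header, so it is an assumption here.

```cpp
#include <cassert>
#include <string>

// Hypothetical helper mirroring the two documented examples only:
//   /Path/my_model.onnx -> /Path/my_model_ctx.onnx
//   /Path/my_model      -> /Path/my_model_ctx.onnx
// Assumption: any other extension is left in place and "_ctx.onnx" is appended.
std::string DefaultCtxModelPath(const std::string& input_path) {
  const std::string ext = ".onnx";
  std::string stem = input_path;
  const bool has_onnx_ext =
      stem.size() >= ext.size() &&
      stem.compare(stem.size() - ext.size(), ext.size(), ext) == 0;
  if (has_onnx_ext) {
    stem.erase(stem.size() - ext.size());  // Drop the trailing ".onnx"
  }
  return stem + "_ctx.onnx";
}
```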

ORT_RUNTIME_CLASS(Ep);
67 changes: 35 additions & 32 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -1138,38 +1138,6 @@ struct SessionOptions : detail::SessionOptionsImpl<OrtSessionOptions> {
ConstSessionOptions GetConst() const { return ConstSessionOptions{this->p_}; }
};

/** \brief Options object used when compiling a model.
*
* Wraps ::OrtModelCompilationOptions object and methods
*/
struct ModelCompilationOptions : detail::Base<OrtModelCompilationOptions> {
using Base = detail::Base<OrtModelCompilationOptions>;
using Base::Base;

explicit ModelCompilationOptions(std::nullptr_t) {} ///< Create an empty ModelCompilationOptions object, must be assigned a valid one to be used.

ModelCompilationOptions(const Env& env, const SessionOptions& session_options); ///< Wraps OrtApi::CreateModelCompilationOptionsFromSessionOptions
ModelCompilationOptions(const Env& env, ConstSessionOptions session_options); ///< Wraps OrtApi::CreateModelCompilationOptionsFromSessionOptions

ModelCompilationOptions& SetInputModelPath(const ORTCHAR_T* input_model_path); ///< Wraps OrtApi::ModelCompilationOptions_SetInputModelPath
ModelCompilationOptions& SetInputModelFromBuffer(const void* input_model_data,
size_t input_model_data_size); ///< Wraps OrtApi::ModelCompilationOptions_SetInputModelFromBuffer
ModelCompilationOptions& SetEpContextEmbedMode(bool embed_ep_context_in_model); ///< Wraps OrtApi::ModelCompilationOptions_SetEpContextEmbedMode
ModelCompilationOptions& SetOutputModelPath(const ORTCHAR_T* output_model_path); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelPath
ModelCompilationOptions& SetOutputModelExternalInitializersFile(const ORTCHAR_T* file_path,
size_t initializer_size_threshold); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelExternalInitializersFile
ModelCompilationOptions& SetOutputModelBuffer(OrtAllocator* allocator, void** output_model_buffer_ptr,
size_t* output_model_buffer_size_ptr); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelBuffer
};

/** \brief Compiles an input model to generate a model with EPContext nodes that execute EP-specific kernels. Wraps OrtApi::CompileModels.
*
* \param env: ORT environment object.
* \param model_compilation_options: Compilation options for a model.
* \return A Status indicating success or failure.
*/
Status CompileModel(const Env& env, const ModelCompilationOptions& model_compilation_options);

/** \brief Wrapper around ::OrtModelMetadata
*
*/
@@ -2873,5 +2841,40 @@ struct Model : detail::ModelImpl<OrtModel> {

ConstModel GetConst() const { return ConstModel{this->p_}; }
};

/** \brief Options object used when compiling a model.
*
* Wraps ::OrtModelCompilationOptions object and methods
*/
struct ModelCompilationOptions : detail::Base<OrtModelCompilationOptions> {
using Base = detail::Base<OrtModelCompilationOptions>;
using Base::Base;

explicit ModelCompilationOptions(std::nullptr_t) {} ///< Create an empty ModelCompilationOptions object, must be assigned a valid one to be used.

ModelCompilationOptions(const Env& env, const SessionOptions& session_options); ///< Wraps OrtApi::CreateModelCompilationOptionsFromSessionOptions
ModelCompilationOptions(const Env& env, ConstSessionOptions session_options); ///< Wraps OrtApi::CreateModelCompilationOptionsFromSessionOptions

ModelCompilationOptions& SetInputModelPath(const ORTCHAR_T* input_model_path); ///< Wraps OrtApi::ModelCompilationOptions_SetInputModelPath
ModelCompilationOptions& SetInputModelFromBuffer(const void* input_model_data,
size_t input_model_data_size); ///< Wraps OrtApi::ModelCompilationOptions_SetInputModelFromBuffer
ModelCompilationOptions& SetInputModel(ConstModel input_model); ///< Wraps OrtApi::ModelCompilationOptions_SetInputModel
ModelCompilationOptions& SetEpContextEmbedMode(bool embed_ep_context_in_model); ///< Wraps OrtApi::ModelCompilationOptions_SetEpContextEmbedMode
ModelCompilationOptions& SetOutputModelPath(const ORTCHAR_T* output_model_path); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelPath
ModelCompilationOptions& SetOutputModelExternalInitializersFile(const ORTCHAR_T* file_path,
size_t initializer_size_threshold); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelExternalInitializersFile
ModelCompilationOptions& SetOutputModelBuffer(OrtAllocator* allocator, void** output_model_buffer_ptr,
size_t* output_model_buffer_size_ptr); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelBuffer
ModelCompilationOptions& SetOutputModelOutStream(OrtOutStreamWriteFunc write_stream_func, void* stream_state); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelOutStream
};

/** \brief Compiles an input model to generate a model with EPContext nodes that execute EP-specific kernels. Wraps OrtCompileApi::CompileModel.
*
* \param env: ORT environment object.
* \param model_compilation_options: Compilation options for a model.
* \return A Status indicating success or failure.
*/
Status CompileModel(const Env& env, const ModelCompilationOptions& model_compilation_options);

} // namespace Ort
#include "onnxruntime_cxx_inline.h"
14 changes: 14 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
@@ -801,6 +801,12 @@ inline ModelCompilationOptions& ModelCompilationOptions::SetInputModelFromBuffer
return *this;
}

inline ModelCompilationOptions& ModelCompilationOptions::SetInputModel(
ConstModel input_model) {
Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetInputModel(this->p_, input_model));
return *this;
}

inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelPath(
const ORTCHAR_T* output_model_path) {
Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetOutputModelPath(this->p_, output_model_path));
@@ -824,6 +830,14 @@ inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelBuffer(
return *this;
}

inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelOutStream(
OrtOutStreamWriteFunc write_stream_func, void* stream_state) {
Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetOutputModelOutStream(this->p_,
write_stream_func,
stream_state));
return *this;
}
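
Since `stream_state` is an opaque pointer, a C++ caller could route the output into any `std::ostream` through a small trampoline. The sketch below assumes the four-parameter callback shape from this revision. `OrtStatus` is forward-declared as a stand-in, and `WriteToOstream` is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <ostream>
#include <sstream>
#include <string>

struct OrtStatus;  // Stand-in so the sketch compiles without the ORT headers.

// Forwards the buffer to the std::ostream passed via stream_state.
// stream_state must have been created from a std::ostream* (convert derived
// stream types to std::ostream* before passing them as void*).
OrtStatus* WriteToOstream(void* stream_state, const void* buffer,
                          std::size_t buffer_num_bytes, std::size_t* num_bytes_written) {
  auto* out = static_cast<std::ostream*>(stream_state);
  out->write(static_cast<const char*>(buffer), static_cast<std::streamsize>(buffer_num_bytes));
  // Reporting 0 bytes makes the caller stop with an error, per the documented
  // contract. A real callback would more likely return a status built with
  // CreateStatus and the ORT_FAIL code.
  *num_bytes_written = out->good() ? buffer_num_bytes : 0;
  return nullptr;
}
```

With the wrapper added in this PR, usage would presumably be along the lines of `options.SetOutputModelOutStream(WriteToOstream, static_cast<std::ostream*>(&file_stream))`, with the callback additionally declared using `ORT_API_CALL`.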

inline ModelCompilationOptions& ModelCompilationOptions::SetEpContextEmbedMode(
bool embed_ep_context_in_model) {
Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetEpContextEmbedMode(
134 changes: 134 additions & 0 deletions onnxruntime/core/framework/ep_context_options.cc
@@ -0,0 +1,134 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <cassert>
#include <limits>
#include <string>
#include <utility>
#include "core/common/common.h"
#include "core/framework/ep_context_options.h"
#include "core/framework/error_code_helper.h"
#include "core/session/onnxruntime_session_options_config_keys.h"

namespace onnxruntime {
namespace epctx {
// class ModelGenOptions

ModelGenOptions::ModelGenOptions(const ConfigOptions& config_options) {
enable = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1";

std::string output_model_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, "");
if (!output_model_path.empty()) {
output_model_location = std::move(output_model_path);  // Reuse the value already read above.
} else {
output_model_location = std::monostate{};
}

output_external_initializers_file_path = config_options.GetConfigOrDefault(
kOrtSessionOptionsEpContextModelExternalInitializersFileName, "");
output_external_initializer_size_threshold = 0;
embed_ep_context_in_model = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "0") == "1";
}

bool ModelGenOptions::HasOutputModelLocation() const {
return !std::holds_alternative<std::monostate>(output_model_location);
}

const std::string* ModelGenOptions::TryGetOutputModelPath() const {
return std::get_if<std::string>(&output_model_location);
}

const BufferHolder* ModelGenOptions::TryGetOutputModelBuffer() const {
return std::get_if<BufferHolder>(&output_model_location);
}

const OutStreamHolder* ModelGenOptions::TryGetOutputModelOutStream() const {
return std::get_if<OutStreamHolder>(&output_model_location);
}
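
The accessors above imply that `output_model_location` is a `std::variant` whose empty state is `std::monostate`. A minimal sketch of that pattern follows. The holder structs and the order of the alternatives are assumptions, since `ep_context_options.h` is not shown in this diff, and `OutputLocation`, `HasLocation`, and `TryGetPath` are hypothetical names.

```cpp
#include <cassert>
#include <string>
#include <variant>

// Hypothetical stand-ins for the PR's holder types (real definitions not shown here).
struct BufferHolder {
  void** buffer_ptr = nullptr;
};
struct OutStreamHolder {
  void* stream_state = nullptr;
};

// Assumed shape: std::monostate means "no destination set".
using OutputLocation = std::variant<std::monostate, std::string, BufferHolder, OutStreamHolder>;

bool HasLocation(const OutputLocation& loc) {
  return !std::holds_alternative<std::monostate>(loc);
}

// Returns a pointer to the path if the variant currently holds one, otherwise nullptr.
const std::string* TryGetPath(const OutputLocation& loc) {
  return std::get_if<std::string>(&loc);
}
```

Returning a pointer from `std::get_if` lets callers probe each destination kind in turn without exceptions, which matches how the `TryGet...` functions are written.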

// class OutStreamBuf

OutStreamBuf::OutStreamBuf(OutStreamHolder out_stream_holder) : out_stream_holder_(out_stream_holder) {
setp(buffer_.data(), buffer_.data() + buffer_.size());
}

OutStreamBuf::~OutStreamBuf() {
sync();
}

// Called when the buffer_ is full. Flushes the buffer_ (via sync()) and then writes the overflow character to buffer_.
std::streambuf::int_type OutStreamBuf::overflow(std::streambuf::int_type ch) {
if (sync() == -1) {
return traits_type::eof();
}

if (ch != traits_type::eof()) {
*pptr() = static_cast<char>(ch);
pbump(1);
}

return ch;
}

// Flushes the entire buffer_ to the user's write function.
int OutStreamBuf::sync() {
if (!last_status_.IsOK()) {
return -1;
}

std::ptrdiff_t num_bytes = pptr() - pbase();
if (num_bytes == 0) {
return 0;
}

// Can only call pbump() with an int, so can only write at most (2^31 - 1) bytes.
if (num_bytes > std::numeric_limits<int>::max()) {
num_bytes = std::numeric_limits<int>::max();
}

std::ptrdiff_t bytes_remaining = num_bytes;
char* ptr = pbase();

while (bytes_remaining > 0) {
size_t bytes_written = 0;
Status status = Status::OK();

ORT_TRY {
status = ToStatus(out_stream_holder_.write_func(out_stream_holder_.stream_state,
ptr, bytes_remaining, &bytes_written));
}
ORT_CATCH(const std::exception& e) {
ORT_HANDLE_EXCEPTION([&]() {
status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"Caught exception while calling user's OrtOutStreamWriteFunc callback: ", e.what());
});
}

if (!status.IsOK()) {
last_status_ = std::move(status);
return -1;
}

if (bytes_written > static_cast<size_t>(bytes_remaining)) {
last_status_ = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "OrtOutStreamWriteFunc wrote more bytes (", bytes_written,
") than requested (", bytes_remaining, ").");
return -1;
}

if (bytes_written == 0) {
last_status_ = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "OrtOutStreamWriteFunc failed to write any data. ",
"Stopping write attempts to avoid a potential infinite loop.");
return -1;
}

bytes_remaining -= static_cast<std::ptrdiff_t>(bytes_written);
ptr += bytes_written;
}

assert(ptr == pptr());
pbump(-static_cast<int>(num_bytes)); // Reset internal pointer to point to the beginning of the buffer_
return 0;
}

} // namespace epctx
} // namespace onnxruntime
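
The buffering approach in `OutStreamBuf` (fill a fixed put area, flush it through a callback in `sync()`, keep the overflow character) can be tried in isolation with the simplified sketch below. It is not the PR's class: error tracking and partial writes are left out, the sink is a `std::function`, and `CallbackStreamBuf` and the 8-byte buffer size are made up for illustration.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <functional>
#include <ostream>
#include <streambuf>
#include <string>
#include <utility>

// Minimal streambuf that flushes a fixed-size buffer to a user callback.
class CallbackStreamBuf : public std::streambuf {
 public:
  using Sink = std::function<void(const char*, std::size_t)>;

  explicit CallbackStreamBuf(Sink sink) : sink_(std::move(sink)) {
    setp(buffer_.data(), buffer_.data() + buffer_.size());
  }

  ~CallbackStreamBuf() override { sync(); }  // Flush whatever is still buffered.

 protected:
  // Called when the put area is full: flush, then store the overflow character.
  int_type overflow(int_type ch) override {
    sync();
    if (!traits_type::eq_int_type(ch, traits_type::eof())) {
      *pptr() = traits_type::to_char_type(ch);
      pbump(1);
    }
    return traits_type::not_eof(ch);
  }

  // Hands the buffered bytes to the sink and resets the put area.
  int sync() override {
    const std::ptrdiff_t num_bytes = pptr() - pbase();
    if (num_bytes > 0) {
      sink_(pbase(), static_cast<std::size_t>(num_bytes));
      setp(buffer_.data(), buffer_.data() + buffer_.size());
    }
    return 0;
  }

 private:
  Sink sink_;
  std::array<char, 8> buffer_{};
};
```

Resetting the put area with `setp()` avoids the `int` limit of `pbump()` that the PR code works around by clamping `num_bytes`.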