-
Notifications
You must be signed in to change notification settings - Fork 3.2k
Compile API: support for OrtModel input and write output to stream #24740
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
adrianlizarraga
wants to merge
14
commits into
main
Choose a base branch
from
adrianl/CompileApi_OrtModelInput_StreamWriteOutput
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+849
−218
Open
Changes from 12 commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
769014f
Compile API: add suport for OrtModel input
adrianlizarraga ffcfab9
Update onnxruntime/core/session/compile_api.cc
adrianlizarraga 025474a
Fix tensorprotoutils when reading ext data from a memory buffer
adrianlizarraga 94596eb
Clean up
adrianlizarraga 134a702
Add C API to write compiled model to stream
adrianlizarraga 7f719a9
Serialize proto to stream. need test
adrianlizarraga 72eeece
Add unit test for compiling to user's write stream
adrianlizarraga 2043a36
Clean up
adrianlizarraga 1ee829f
Handle case where user's write function never writes data (return error)
adrianlizarraga a21be47
Reduce repetition in unit tests
adrianlizarraga 78df685
Add Python bindings for compiling to a stream
adrianlizarraga 09a90a1
fix min build unused var
adrianlizarraga d12f435
Flush before checking status of write
adrianlizarraga e5c27c3
Simplify: expect write function to write entire provided buffer
adrianlizarraga File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -453,6 +453,22 @@ typedef OrtStatus*(ORT_API_CALL* EpSelectionDelegate)(_In_ const OrtEpDevice** e | |
_Out_ size_t* num_selected, | ||
_In_ void* state); | ||
|
||
/** \brief Function that writes a buffer to a stream. | ||
* | ||
* \param steam_state Opaque pointer holding the state for the user's stream. | ||
* \param buffer The buffer to write to the stream. | ||
* \param buffer_num_bytes The size of the buffer in bytes. | ||
* \param num_bytes_written Output parameter that should be set to the number of bytes written to the stream. | ||
* | ||
* \return OrtStatus* Write status. Return nullptr on success. | ||
* Use CreateStatus to provide error info. Use ORT_FAIL as the error code. | ||
* ORT will release the OrtStatus* if not null. | ||
*/ | ||
typedef OrtStatus*(ORT_API_CALL* OrtOutStreamWriteFunc)(_In_ void* stream_state, | ||
_In_ const void* buffer, | ||
_In_ size_t buffer_num_bytes, | ||
_Out_ size_t* num_bytes_written); | ||
|
||
/** \brief Algorithm to use for cuDNN Convolution Op | ||
*/ | ||
typedef enum OrtCudnnConvAlgoSearch { | ||
|
@@ -5831,8 +5847,9 @@ struct OrtCompileApi { | |
|
||
/** \brief Sets the file path to the input ONNX model to compile. | ||
* | ||
* The input model's location (e.g., file path or memory buffer) must be set with either | ||
* ModelCompilationOptions_SetInputModelPath or ModelCompilationOptions_SetInputModelFromBuffer. | ||
* The input model's location (e.g., file path, memory buffer, or OrtModel) must be set with one of the functions that | ||
* begin with ModelCompilationOptions_SetInputModel____, otherwise CompileModel() will return an OrtStatus with error | ||
* code ORT_INVALID_ARGUMENT. | ||
* | ||
* \param[in] model_compile_options The OrtModelCompilationOptions instance. | ||
* \param[in] input_model_path Null terminated string of the path (wchar on Windows, char otherwise). | ||
|
@@ -5846,8 +5863,9 @@ struct OrtCompileApi { | |
|
||
/** \brief Sets the buffer that stores the bytes of the loaded ONNX model to compile. | ||
* | ||
* The input model's location (e.g., file path or memory buffer) must be set with either | ||
* ModelCompilationOptions_SetInputModelPath or ModelCompilationOptions_SetInputModelFromBuffer. | ||
* The input model's location (e.g., file path, memory buffer, or OrtModel) must be set with one of the functions that | ||
* begin with ModelCompilationOptions_SetInputModel____, otherwise CompileModel() will return an OrtStatus with error | ||
* code ORT_INVALID_ARGUMENT. | ||
* | ||
* \param[in] model_compile_options The OrtModelCompilationOptions instance. | ||
* \param[in] input_model_data Buffer containing the loaded ONNX model bytes. | ||
|
@@ -5864,10 +5882,10 @@ struct OrtCompileApi { | |
|
||
/** \brief Sets the file path for the output ONNX model generated by CompileModel. | ||
* | ||
* The output model's location (e.g., file path or memory buffer) can be set with either | ||
* ModelCompilationOptions_SetOutputModelPath or ModelCompilationOptions_SetOutputModelBuffer. | ||
* The output model's destination (e.g., file path, memory buffer, or stream) can be set with any of the functions | ||
* that begin with ModelCompilationOptions_SetOutputModel____. | ||
* | ||
* If the output model's location is not set, ONNX Runtime will generate an output file with a path based on | ||
* If the output model's destination is not set, ONNX Runtime will generate an output file with a path based on | ||
* the input model's file path. Examples: | ||
* /Path/my_model.onnx -> /Path/my_model_ctx.onnx | ||
* /Path/my_model -> /Path/my_model_ctx.onnx | ||
|
@@ -5906,10 +5924,10 @@ struct OrtCompileApi { | |
* | ||
* The caller passes an OrtAllocator that ONNX Runtime uses to allocate memory for the buffer. | ||
* | ||
* The output model's location (e.g., file path or memory buffer) can be set with either | ||
* ModelCompilationOptions_SetOutputModelPath or ModelCompilationOptions_SetOutputModelBuffer. | ||
* The output model's destination (e.g., file path, memory buffer, or stream) can be set with any of the functions | ||
* that begin with ModelCompilationOptions_SetOutputModel____. | ||
* | ||
* If the output model's location is not set, ONNX Runtime will generate an output file with a path based on | ||
* If the output model's destination is not set, ONNX Runtime will generate an output file with a path based on | ||
* the input model's file path. Examples: | ||
* /Path/my_model.onnx -> /Path/my_model_ctx.onnx | ||
* /Path/my_model -> /Path/my_model_ctx.onnx | ||
|
@@ -5964,6 +5982,48 @@ struct OrtCompileApi { | |
* \since Version 1.22. | ||
*/ | ||
ORT_API2_STATUS(CompileModel, _In_ const OrtEnv* env, _In_ const OrtModelCompilationOptions* model_options); | ||
|
||
/** \brief Sets the input OrtModel instance to compile. | ||
* | ||
* The input model's location (e.g., file path, memory buffer, or OrtModel) must be set with one of the functions that | ||
* begin with ModelCompilationOptions_SetInputModel____, otherwise CompileModel() will return an OrtStatus with error | ||
* code ORT_INVALID_ARGUMENT. | ||
* | ||
* \param[in] model_compile_options The OrtModelCompilationOptions instance. | ||
* \param[in] input_model The OrtModel instance to compile. | ||
* | ||
* \snippet{doc} snippets.dox OrtStatus Return Value | ||
* | ||
* \since Version 1.23. | ||
*/ | ||
ORT_API2_STATUS(ModelCompilationOptions_SetInputModel, _In_ OrtModelCompilationOptions* model_compile_options, | ||
_In_ const OrtModel* input_model); | ||
Comment on lines
+5997
to
+5998
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this work with both variants of OrtModel? One completely created with the model editor API and one that augments and existing model? I think it's fine to only support the former. |
||
|
||
/** \brief Sets an output stream used to write out the output model's serialized ONNX bytes. | ||
* | ||
* The provided write function is called repeatedly until then entire output model has been written out. Each call to | ||
* the write function must write at least one byte, otherwise CompileModel() will return an OrtStatus with error code | ||
* ORT_FAIL. | ||
* | ||
* The output model's destination (e.g., file path, memory buffer, or stream) can be set with any of the functions | ||
* that begin with ModelCompilationOptions_SetOutputModel____. | ||
* | ||
* If the output model's destination is not set, ONNX Runtime will generate an output file with a path based on | ||
* the input model's file path. Examples: | ||
* /Path/my_model.onnx -> /Path/my_model_ctx.onnx | ||
* /Path/my_model -> /Path/my_model_ctx.onnx | ||
* | ||
* \param[in] model_compile_options The OrtModelCompilationOptions instance. | ||
* \param[in] write_stream_func The OrtOutStreamWriteFunc function to call when writing out the model. | ||
* \param[in] state Opaque stream state passed as the first argument to OrtOutStreamWriteFunc. Can be null. | ||
* | ||
* \snippet{doc} snippets.dox OrtStatus Return Value | ||
* | ||
* \since Version 1.23. | ||
*/ | ||
ORT_API2_STATUS(ModelCompilationOptions_SetOutputModelOutStream, | ||
_In_ OrtModelCompilationOptions* model_compile_options, | ||
_In_ OrtOutStreamWriteFunc write_stream_func, _In_ void* stream_state); | ||
}; | ||
|
||
ORT_RUNTIME_CLASS(Ep); | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#include <cassert> | ||
#include <limits> | ||
#include <string> | ||
#include <utility> | ||
#include "core/common/common.h" | ||
#include "core/framework/ep_context_options.h" | ||
#include "core/framework/error_code_helper.h" | ||
#include "core/session/onnxruntime_session_options_config_keys.h" | ||
|
||
namespace onnxruntime { | ||
namespace epctx { | ||
// class ModelGenOptions | ||
|
||
ModelGenOptions::ModelGenOptions(const ConfigOptions& config_options) { | ||
enable = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1"; | ||
|
||
std::string output_model_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); | ||
if (!output_model_path.empty()) { | ||
output_model_location = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); | ||
} else { | ||
output_model_location = std::monostate{}; | ||
} | ||
|
||
output_external_initializers_file_path = config_options.GetConfigOrDefault( | ||
kOrtSessionOptionsEpContextModelExternalInitializersFileName, ""); | ||
output_external_initializer_size_threshold = 0; | ||
embed_ep_context_in_model = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "0") == "1"; | ||
} | ||
|
||
bool ModelGenOptions::HasOutputModelLocation() const { | ||
return !std::holds_alternative<std::monostate>(output_model_location); | ||
} | ||
|
||
const std::string* ModelGenOptions::TryGetOutputModelPath() const { | ||
return std::get_if<std::string>(&output_model_location); | ||
} | ||
|
||
const BufferHolder* ModelGenOptions::TryGetOutputModelBuffer() const { | ||
return std::get_if<BufferHolder>(&output_model_location); | ||
} | ||
|
||
const OutStreamHolder* ModelGenOptions::TryGetOutputModelOutStream() const { | ||
return std::get_if<OutStreamHolder>(&output_model_location); | ||
} | ||
|
||
// class OutStreamBuf | ||
|
||
OutStreamBuf::OutStreamBuf(OutStreamHolder out_stream_holder) : out_stream_holder_(out_stream_holder) { | ||
setp(buffer_.data(), buffer_.data() + buffer_.size()); | ||
} | ||
|
||
OutStreamBuf::~OutStreamBuf() { | ||
sync(); | ||
} | ||
|
||
// Called when the buffer_ is full. Flushes the buffer_ (via sync()) and then writes the overflow character to buffer_. | ||
std::streambuf::int_type OutStreamBuf::overflow(std::streambuf::int_type ch) { | ||
if (sync() == -1) { | ||
return traits_type::eof(); | ||
} | ||
|
||
if (ch != traits_type::eof()) { | ||
*pptr() = static_cast<char>(ch); | ||
pbump(1); | ||
} | ||
|
||
return ch; | ||
} | ||
|
||
// Flushes the entire buffer_ to the user's write function. | ||
int OutStreamBuf::sync() { | ||
if (!last_status_.IsOK()) { | ||
return -1; | ||
} | ||
|
||
std::ptrdiff_t num_bytes = pptr() - pbase(); | ||
if (num_bytes == 0) { | ||
return 0; | ||
} | ||
|
||
// Can only call pbump() with an int, so can only write at most (2^31 - 1) bytes. | ||
if (num_bytes > std::numeric_limits<int>::max()) { | ||
num_bytes = std::numeric_limits<int>::max(); | ||
} | ||
|
||
std::ptrdiff_t bytes_remaining = num_bytes; | ||
char* ptr = pbase(); | ||
|
||
while (bytes_remaining > 0) { | ||
size_t bytes_written = 0; | ||
Status status = Status::OK(); | ||
|
||
ORT_TRY { | ||
status = ToStatus(out_stream_holder_.write_func(out_stream_holder_.stream_state, | ||
ptr, bytes_remaining, &bytes_written)); | ||
} | ||
ORT_CATCH(const std::exception& e) { | ||
ORT_HANDLE_EXCEPTION([&]() { | ||
status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, | ||
"Caught exception while calling user's OrtOutStreamWriteFunc callback: ", e.what()); | ||
}); | ||
} | ||
|
||
if (!status.IsOK()) { | ||
last_status_ = std::move(status); | ||
return -1; | ||
} | ||
|
||
if (bytes_written > static_cast<size_t>(bytes_remaining)) { | ||
last_status_ = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "OrtOutStreamWriteFunc wrote more bytes (", bytes_written, | ||
") than requested (", bytes_remaining, ")."); | ||
return -1; | ||
} | ||
|
||
if (bytes_written == 0) { | ||
last_status_ = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "OrtOutStreamWriteFunc failed to write any data. ", | ||
"Stopping write attempts to avoid a potential infinite loop."); | ||
return -1; | ||
} | ||
|
||
bytes_remaining -= static_cast<std::ptrdiff_t>(bytes_written); | ||
ptr += bytes_written; | ||
} | ||
|
||
assert(ptr == pptr()); | ||
pbump(-static_cast<int>(num_bytes)); // Reset internal pointer to point to the beginning of the buffer_ | ||
return 0; | ||
} | ||
|
||
} // namespace epctx | ||
} // namespace onnxruntime |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we need
num_bytes_written
? doesn't that complicate things vs. requiring the function implementer to handle all bytes they were given?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good call. Removed.