diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 25b6d72394e0c..e8420a12fe01e 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -453,6 +453,20 @@ typedef OrtStatus*(ORT_API_CALL* EpSelectionDelegate)(_In_ const OrtEpDevice** e _Out_ size_t* num_selected, _In_ void* state); +/** \brief Function that writes a buffer to a stream. + * + * \param stream_state Opaque pointer holding the state for the user's stream. + * \param buffer The buffer to write to the stream. + * \param buffer_num_bytes The size of the buffer in bytes. + * + * \return OrtStatus* Write status. Return nullptr on success. + * Use CreateStatus to provide error info. Use ORT_FAIL as the error code. + * ORT will release the OrtStatus* if not null. + */ +typedef OrtStatus*(ORT_API_CALL* OrtOutStreamWriteFunc)(_In_ void* stream_state, + _In_ const void* buffer, + _In_ size_t buffer_num_bytes); + /** \brief Algorithm to use for cuDNN Convolution Op */ typedef enum OrtCudnnConvAlgoSearch { @@ -5831,8 +5845,9 @@ struct OrtCompileApi { /** \brief Sets the file path to the input ONNX model to compile. * - * The input model's location (e.g., file path or memory buffer) must be set with either - * ModelCompilationOptions_SetInputModelPath or ModelCompilationOptions_SetInputModelFromBuffer. + * The input model's location (e.g., file path, memory buffer, or OrtModel) must be set with one of the functions that + * begin with ModelCompilationOptions_SetInputModel____, otherwise CompileModel() will return an OrtStatus with error + * code ORT_INVALID_ARGUMENT. * * \param[in] model_compile_options The OrtModelCompilationOptions instance. * \param[in] input_model_path Null terminated string of the path (wchar on Windows, char otherwise). @@ -5846,8 +5861,9 @@ struct OrtCompileApi { /** \brief Sets the buffer that stores the bytes of the loaded ONNX model to compile. 
* - * The input model's location (e.g., file path or memory buffer) must be set with either - * ModelCompilationOptions_SetInputModelPath or ModelCompilationOptions_SetInputModelFromBuffer. + * The input model's location (e.g., file path, memory buffer, or OrtModel) must be set with one of the functions that + * begin with ModelCompilationOptions_SetInputModel____, otherwise CompileModel() will return an OrtStatus with error + * code ORT_INVALID_ARGUMENT. * * \param[in] model_compile_options The OrtModelCompilationOptions instance. * \param[in] input_model_data Buffer containing the loaded ONNX model bytes. @@ -5864,10 +5880,10 @@ struct OrtCompileApi { /** \brief Sets the file path for the output ONNX model generated by CompileModel. * - * The output model's location (e.g., file path or memory buffer) can be set with either - * ModelCompilationOptions_SetOutputModelPath or ModelCompilationOptions_SetOutputModelBuffer. + * The output model's destination (e.g., file path, memory buffer, or stream) can be set with any of the functions + * that begin with ModelCompilationOptions_SetOutputModel____. * - * If the output model's location is not set, ONNX Runtime will generate an output file with a path based on + * If the output model's destination is not set, ONNX Runtime will generate an output file with a path based on * the input model's file path. Examples: * /Path/my_model.onnx -> /Path/my_model_ctx.onnx * /Path/my_model -> /Path/my_model_ctx.onnx @@ -5906,10 +5922,10 @@ struct OrtCompileApi { * * The caller passes an OrtAllocator that ONNX Runtime uses to allocate memory for the buffer. * - * The output model's location (e.g., file path or memory buffer) can be set with either - * ModelCompilationOptions_SetOutputModelPath or ModelCompilationOptions_SetOutputModelBuffer. + * The output model's destination (e.g., file path, memory buffer, or stream) can be set with any of the functions + * that begin with ModelCompilationOptions_SetOutputModel____. 
* - * If the output model's location is not set, ONNX Runtime will generate an output file with a path based on + * If the output model's destination is not set, ONNX Runtime will generate an output file with a path based on * the input model's file path. Examples: * /Path/my_model.onnx -> /Path/my_model_ctx.onnx * /Path/my_model -> /Path/my_model_ctx.onnx @@ -5964,6 +5980,47 @@ struct OrtCompileApi { * \since Version 1.22. */ ORT_API2_STATUS(CompileModel, _In_ const OrtEnv* env, _In_ const OrtModelCompilationOptions* model_options); + + /** \brief Sets the input OrtModel instance to compile. + * + * The input model's location (e.g., file path, memory buffer, or OrtModel) must be set with one of the functions that + * begin with ModelCompilationOptions_SetInputModel____, otherwise CompileModel() will return an OrtStatus with error + * code ORT_INVALID_ARGUMENT. + * + * \param[in] model_compile_options The OrtModelCompilationOptions instance. + * \param[in] input_model The OrtModel instance to compile. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(ModelCompilationOptions_SetInputModel, _In_ OrtModelCompilationOptions* model_compile_options, + _In_ const OrtModel* input_model); + + /** \brief Sets an output stream used to write out the output model's serialized ONNX bytes. + * + * The provided write function may be called repeatedly until the entire output model has been written out. Each call + * to the write function is expected to write the entire buffer to the underlying stream. + * + * The output model's destination (e.g., file path, memory buffer, or stream) can be set with any of the functions + * that begin with ModelCompilationOptions_SetOutputModel____. + * + * If the output model's destination is not set, ONNX Runtime will generate an output file with a path based on + * the input model's file path. 
Examples: + * /Path/my_model.onnx -> /Path/my_model_ctx.onnx + * /Path/my_model -> /Path/my_model_ctx.onnx + * + * \param[in] model_compile_options The OrtModelCompilationOptions instance. + * \param[in] write_stream_func The OrtOutStreamWriteFunc function to call when writing out the model. + * \param[in] stream_state Opaque stream state passed as the first argument to OrtOutStreamWriteFunc. Can be null. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(ModelCompilationOptions_SetOutputModelOutStream, + _In_ OrtModelCompilationOptions* model_compile_options, + _In_ OrtOutStreamWriteFunc write_stream_func, _In_ void* stream_state); }; ORT_RUNTIME_CLASS(Ep); diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 8876c40fe9e6c..cebf1404b631a 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -1138,38 +1138,6 @@ struct SessionOptions : detail::SessionOptionsImpl { ConstSessionOptions GetConst() const { return ConstSessionOptions{this->p_}; } }; -/** \brief Options object used when compiling a model. - * - * Wraps ::OrtModelCompilationOptions object and methods - */ -struct ModelCompilationOptions : detail::Base { - using Base = detail::Base; - using Base::Base; - - explicit ModelCompilationOptions(std::nullptr_t) {} ///< Create an empty ModelCompilationOptions object, must be assigned a valid one to be used. 
- - ModelCompilationOptions(const Env& env, const SessionOptions& session_options); ///< Wraps OrtApi::CreateModelCompilationOptionsFromSessionOptions - ModelCompilationOptions(const Env& env, ConstSessionOptions session_options); ///< Wraps OrtApi::CreateModelCompilationOptionsFromSessionOptions - - ModelCompilationOptions& SetInputModelPath(const ORTCHAR_T* input_model_path); ///< Wraps OrtApi::ModelCompilationOptions_SetInputModelPath - ModelCompilationOptions& SetInputModelFromBuffer(const void* input_model_data, - size_t input_model_data_size); ///< Wraps OrtApi::ModelCompilationOptions_SetInputModelFromBuffer - ModelCompilationOptions& SetEpContextEmbedMode(bool embed_ep_context_in_model); ///< Wraps OrtApi::ModelCompilationOptions_SetEpContextEmbedMode - ModelCompilationOptions& SetOutputModelPath(const ORTCHAR_T* output_model_path); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelPath - ModelCompilationOptions& SetOutputModelExternalInitializersFile(const ORTCHAR_T* file_path, - size_t initializer_size_threshold); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelExternalInitializersFile - ModelCompilationOptions& SetOutputModelBuffer(OrtAllocator* allocator, void** output_model_buffer_ptr, - size_t* output_model_buffer_size_ptr); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelBuffer -}; - -/** \brief Compiles an input model to generate a model with EPContext nodes that execute EP-specific kernels. Wraps OrtApi::CompileModels. - * - * \param env: ORT environment object. - * \param model_compilation_options: Compilation options for a model. - * \return A Status indicating success or failure. - */ -Status CompileModel(const Env& env, const ModelCompilationOptions& model_compilation_options); - /** \brief Wrapper around ::OrtModelMetadata * */ @@ -2873,5 +2841,40 @@ struct Model : detail::ModelImpl { ConstModel GetConst() const { return ConstModel{this->p_}; } }; + +/** \brief Options object used when compiling a model. 
+ * + * Wraps ::OrtModelCompilationOptions object and methods + */ +struct ModelCompilationOptions : detail::Base { + using Base = detail::Base; + using Base::Base; + + explicit ModelCompilationOptions(std::nullptr_t) {} ///< Create an empty ModelCompilationOptions object, must be assigned a valid one to be used. + + ModelCompilationOptions(const Env& env, const SessionOptions& session_options); ///< Wraps OrtApi::CreateModelCompilationOptionsFromSessionOptions + ModelCompilationOptions(const Env& env, ConstSessionOptions session_options); ///< Wraps OrtApi::CreateModelCompilationOptionsFromSessionOptions + + ModelCompilationOptions& SetInputModelPath(const ORTCHAR_T* input_model_path); ///< Wraps OrtApi::ModelCompilationOptions_SetInputModelPath + ModelCompilationOptions& SetInputModelFromBuffer(const void* input_model_data, + size_t input_model_data_size); ///< Wraps OrtApi::ModelCompilationOptions_SetInputModelFromBuffer + ModelCompilationOptions& SetInputModel(ConstModel input_model); ///< Wraps OrtApi::ModelCompilationOptions_SetInputModel + ModelCompilationOptions& SetEpContextEmbedMode(bool embed_ep_context_in_model); ///< Wraps OrtApi::ModelCompilationOptions_SetEpContextEmbedMode + ModelCompilationOptions& SetOutputModelPath(const ORTCHAR_T* output_model_path); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelPath + ModelCompilationOptions& SetOutputModelExternalInitializersFile(const ORTCHAR_T* file_path, + size_t initializer_size_threshold); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelExternalInitializersFile + ModelCompilationOptions& SetOutputModelBuffer(OrtAllocator* allocator, void** output_model_buffer_ptr, + size_t* output_model_buffer_size_ptr); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelBuffer + ModelCompilationOptions& SetOutputModelOutStream(OrtOutStreamWriteFunc write_stream_func, void* stream_state); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelOutStream +}; + +/** \brief Compiles an input model 
to generate a model with EPContext nodes that execute EP-specific kernels. Wraps OrtApi::CompileModel. + * + * \param env: ORT environment object. + * \param model_compilation_options: Compilation options for a model. + * \return A Status indicating success or failure. + */ +Status CompileModel(const Env& env, const ModelCompilationOptions& model_compilation_options); + } // namespace Ort #include "onnxruntime_cxx_inline.h" diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 0d0b3198a8736..883e29bfd9a37 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -801,6 +801,12 @@ inline ModelCompilationOptions& ModelCompilationOptions::SetInputModelFromBuffer return *this; } +inline ModelCompilationOptions& ModelCompilationOptions::SetInputModel( + ConstModel input_model) { + Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetInputModel(this->p_, input_model)); + return *this; +} + inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelPath( const ORTCHAR_T* output_model_path) { Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetOutputModelPath(this->p_, output_model_path)); @@ -824,6 +830,14 @@ inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelBuffer( return *this; } +inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelOutStream( + OrtOutStreamWriteFunc write_stream_func, void* stream_state) { + Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetOutputModelOutStream(this->p_, + write_stream_func, + stream_state)); + return *this; +} + inline ModelCompilationOptions& ModelCompilationOptions::SetEpContextEmbedMode( bool embed_ep_context_in_model) { Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetEpContextEmbedMode( diff --git a/onnxruntime/core/framework/ep_context_options.cc 
b/onnxruntime/core/framework/ep_context_options.cc new file mode 100644 index 0000000000000..60d2e25cc619f --- /dev/null +++ b/onnxruntime/core/framework/ep_context_options.cc @@ -0,0 +1,114 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include +#include "core/common/common.h" +#include "core/framework/ep_context_options.h" +#include "core/framework/error_code_helper.h" +#include "core/session/onnxruntime_session_options_config_keys.h" + +namespace onnxruntime { +namespace epctx { +// class ModelGenOptions + +ModelGenOptions::ModelGenOptions(const ConfigOptions& config_options) { + enable = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1"; + + std::string output_model_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); + if (!output_model_path.empty()) { + output_model_location = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); + } else { + output_model_location = std::monostate{}; + } + + output_external_initializers_file_path = config_options.GetConfigOrDefault( + kOrtSessionOptionsEpContextModelExternalInitializersFileName, ""); + output_external_initializer_size_threshold = 0; + embed_ep_context_in_model = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "0") == "1"; +} + +bool ModelGenOptions::HasOutputModelLocation() const { + return !std::holds_alternative(output_model_location); +} + +const std::string* ModelGenOptions::TryGetOutputModelPath() const { + return std::get_if(&output_model_location); +} + +const BufferHolder* ModelGenOptions::TryGetOutputModelBuffer() const { + return std::get_if(&output_model_location); +} + +const OutStreamHolder* ModelGenOptions::TryGetOutputModelOutStream() const { + return std::get_if(&output_model_location); +} + +// class OutStreamBuf + +OutStreamBuf::OutStreamBuf(OutStreamHolder out_stream_holder) : 
out_stream_holder_(out_stream_holder) { + setp(buffer_.data(), buffer_.data() + buffer_.size()); +} + +OutStreamBuf::~OutStreamBuf() { + sync(); +} + +// Called when the buffer_ is full. Flushes the buffer_ (via sync()) and then writes the overflow character to buffer_. +std::streambuf::int_type OutStreamBuf::overflow(std::streambuf::int_type ch) { + if (sync() == -1) { + return traits_type::eof(); + } + + if (ch != traits_type::eof()) { + *pptr() = static_cast(ch); + pbump(1); + } + + return ch; +} + +// Flushes the entire buffer_ to the user's write function. +int OutStreamBuf::sync() { + if (!last_status_.IsOK()) { + return -1; + } + + std::ptrdiff_t num_bytes = pptr() - pbase(); + if (num_bytes == 0) { + return 0; + } + + // Can only call pbump() with an int, so can only write at most (2^31 - 1) bytes. + if (num_bytes > std::numeric_limits::max()) { + num_bytes = std::numeric_limits::max(); + } + + char* ptr = pbase(); + + Status status = Status::OK(); + + ORT_TRY { + status = ToStatus(out_stream_holder_.write_func(out_stream_holder_.stream_state, + ptr, num_bytes)); + } + ORT_CATCH(const std::exception& e) { + ORT_HANDLE_EXCEPTION([&]() { + status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Caught exception while calling user's OrtOutStreamWriteFunc callback: ", e.what()); + }); + } + + if (!status.IsOK()) { + last_status_ = std::move(status); + return -1; + } + + pbump(-static_cast(num_bytes)); // Reset internal pointer to point to the beginning of the buffer_ + return 0; +} + +} // namespace epctx +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/ep_context_options.h b/onnxruntime/core/framework/ep_context_options.h new file mode 100644 index 0000000000000..7e46e664f90ce --- /dev/null +++ b/onnxruntime/core/framework/ep_context_options.h @@ -0,0 +1,87 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include +#include +#include +#include +#include "core/framework/allocator.h" +#include "core/framework/config_options.h" + +namespace onnxruntime { +namespace epctx { +/// +/// Holds the buffer that will store the output model and the allocator used to allocate the memory. +/// +struct BufferHolder { + void** buffer_ptr = nullptr; + size_t* buffer_size_ptr = nullptr; + AllocatorPtr buffer_allocator = nullptr; +}; + +/// +/// Holds the opaque stream state and the write function that ORT calls to write out the output model. +/// +struct OutStreamHolder { + OrtOutStreamWriteFunc write_func = nullptr; + void* stream_state = nullptr; // Opaque pointer to user's stream state. Passed as first argument to write_func. +}; + +/// +/// Stores EPContext model generation options. Used in SessionOptions. +/// +struct ModelGenOptions { + ModelGenOptions() = default; + + // Initializes from string key/value pairs in session config options. + explicit ModelGenOptions(const ConfigOptions& config_options); + + bool enable = false; + bool overwrite_existing_output_file = false; + bool error_if_no_compiled_nodes = false; + bool embed_ep_context_in_model = false; + + std::variant // Function to write the output model to a user's stream. + output_model_location{}; + + std::string output_external_initializers_file_path; + size_t output_external_initializer_size_threshold = 0; + + bool HasOutputModelLocation() const; + const std::string* TryGetOutputModelPath() const; + const BufferHolder* TryGetOutputModelBuffer() const; + const OutStreamHolder* TryGetOutputModelOutStream() const; +}; + +// Class that wraps the user's OrtOutStreamWriteFunc function to enable use with +// C++'s std::ostream. 
+// Example: +// OutStreamHolder stream_holder{write_func, stream_state}; +// std::unique_ptr out_stream_buf = std::make_unique(stream_holder); +// std::ostream out_stream(out_stream_buf.get()); +class OutStreamBuf : public std::streambuf { + public: + explicit OutStreamBuf(OutStreamHolder out_stream_holder); + ~OutStreamBuf(); + + const Status& GetStatus() const { + return last_status_; + } + + protected: + int_type overflow(int_type ch) override; + int sync() override; + + private: + OutStreamHolder out_stream_holder_{}; + std::array buffer_{}; + Status last_status_{}; +}; + +} // namespace epctx +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index 8ed5eeaa8d44f..9465b8a7d4c87 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -9,6 +9,7 @@ #include "core/common/inlined_containers.h" #include "core/common/string_utils.h" #include "core/framework/compute_capability.h" +#include "core/framework/ep_context_options.h" #include "core/framework/execution_providers.h" #include "core/framework/func_kernel.h" #include "core/framework/kernel_lookup.h" @@ -794,7 +795,7 @@ static Status GetValidatedEpContextPath(const std::filesystem::path& ep_context_ static Status CreateEpContextModel(const ExecutionProviders& execution_providers, const Graph& graph, - const EpContextModelGenerationOptions& ep_context_gen_options, + const epctx::ModelGenOptions& ep_context_gen_options, const logging::Logger& logger) { InlinedVector all_ep_context_nodes; for (const auto& ep : execution_providers) { @@ -824,15 +825,16 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers return std::make_pair(false, static_cast(nullptr)); }; - bool saving_to_buffer = ep_context_gen_options.output_model_buffer_ptr != nullptr && - ep_context_gen_options.output_model_buffer_size_ptr != nullptr && - 
ep_context_gen_options.output_model_buffer_allocator != nullptr; + const epctx::BufferHolder* output_buffer_holder = ep_context_gen_options.TryGetOutputModelBuffer(); + const epctx::OutStreamHolder* output_stream_holder = ep_context_gen_options.TryGetOutputModelOutStream(); + const std::string* output_model_path_ptr = ep_context_gen_options.TryGetOutputModelPath(); - std::filesystem::path context_cache_path; - if (!saving_to_buffer || !graph.ModelPath().empty()) { - ORT_RETURN_IF_ERROR(GetValidatedEpContextPath(ep_context_gen_options.output_model_file_path, + std::filesystem::path valid_output_model_path; + if (output_model_path_ptr != nullptr || !graph.ModelPath().empty()) { + std::string output_model_path = (output_model_path_ptr != nullptr) ? *output_model_path_ptr : ""; + ORT_RETURN_IF_ERROR(GetValidatedEpContextPath(output_model_path, graph.ModelPath(), - context_cache_path, + valid_output_model_path, ep_context_gen_options.overwrite_existing_output_file)); } @@ -894,25 +896,41 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers ModelSavingOptions model_saving_options{ini_size_threshold}; - if (saving_to_buffer) { + if (output_buffer_holder != nullptr) { + // Write output model into a buffer ORT allocates for the user. ORT_RETURN_IF_ERROR(ep_context_model.MainGraph().Resolve()); - // TODO(adrianlizarraga): Investigate if we can make this more memory efficient. - // May be able to use allocator to directly allocate the ModelProto to avoid a copy. 
ONNX_NAMESPACE::ModelProto model_proto = ep_context_model.ToGraphProtoWithExternalInitializers(external_ini_path, - context_cache_path, + valid_output_model_path, model_saving_options); size_t buffer_size = model_proto.ByteSizeLong(); ORT_RETURN_IF(buffer_size > static_cast(std::numeric_limits::max()), "Cannot serialize ONNX ModelProto larger than 2GB"); - AllocatorPtr allocator = ep_context_gen_options.output_model_buffer_allocator; + AllocatorPtr allocator = output_buffer_holder->buffer_allocator; IAllocatorUniquePtr buffer = IAllocator::MakeUniquePtr(allocator, buffer_size); model_proto.SerializeToArray(buffer.get(), static_cast(buffer_size)); - *ep_context_gen_options.output_model_buffer_size_ptr = buffer_size; - *ep_context_gen_options.output_model_buffer_ptr = buffer.release(); + *output_buffer_holder->buffer_size_ptr = buffer_size; + *output_buffer_holder->buffer_ptr = buffer.release(); + } else if (output_stream_holder != nullptr) { + // Write output model to user's output stream. + ORT_RETURN_IF_ERROR(ep_context_model.MainGraph().Resolve()); + ONNX_NAMESPACE::ModelProto model_proto = ep_context_model.ToGraphProtoWithExternalInitializers(external_ini_path, + valid_output_model_path, + model_saving_options); + size_t buffer_size = model_proto.ByteSizeLong(); + ORT_RETURN_IF(buffer_size > static_cast(std::numeric_limits::max()), + "Cannot serialize ONNX ModelProto larger than 2GB"); + + auto out_stream_buf = std::make_unique(*output_stream_holder); + std::ostream out_stream(out_stream_buf.get()); + + model_proto.SerializeToOstream(&out_stream); + out_stream.flush(); + ORT_RETURN_IF_ERROR(out_stream_buf->GetStatus()); } else { - ORT_RETURN_IF_ERROR(Model::SaveWithExternalInitializers(ep_context_model, context_cache_path, + // Write output model to file. 
+ ORT_RETURN_IF_ERROR(Model::SaveWithExternalInitializers(ep_context_model, valid_output_model_path, external_ini_path, model_saving_options)); } @@ -1164,7 +1182,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, const ConfigOptions& config_options, const logging::Logger& logger, Mode mode, - const EpContextModelGenerationOptions& ep_context_gen_options, + const epctx::ModelGenOptions& ep_context_gen_options, const layout_transformation::DebugGraphFn& debug_graph_fn) const { // It is a greedy partitioning algorithm per provider preferences user provided when calling ONNX RUNTIME right now. // 1. Execution providers' capabilities are checked one by one. @@ -1211,12 +1229,15 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, if (mode == Mode::kNormal || mode == Mode::kAssignOnly) { #if !defined(ORT_MINIMAL_BUILD) - if (ep_context_gen_options.enable && ep_context_gen_options.output_model_buffer_ptr == nullptr) { - // Check before EP compile graphs - std::filesystem::path context_cache_path; - ORT_RETURN_IF_ERROR(GetValidatedEpContextPath(ep_context_gen_options.output_model_file_path, graph.ModelPath(), - context_cache_path, - ep_context_gen_options.overwrite_existing_output_file)); + if (ep_context_gen_options.enable) { + if (const std::string* output_model_path_ptr = ep_context_gen_options.TryGetOutputModelPath(); + output_model_path_ptr != nullptr) { + // Check before EP compile graphs + std::filesystem::path context_cache_path; + ORT_RETURN_IF_ERROR(GetValidatedEpContextPath(*output_model_path_ptr, graph.ModelPath(), + context_cache_path, + ep_context_gen_options.overwrite_existing_output_file)); + } } // We use this only if Resource Aware Partitioning is enabled for any of the EPs diff --git a/onnxruntime/core/framework/graph_partitioner.h b/onnxruntime/core/framework/graph_partitioner.h index 6e36d79701fd7..abe46cea58ab2 100644 --- a/onnxruntime/core/framework/graph_partitioner.h +++ 
b/onnxruntime/core/framework/graph_partitioner.h @@ -15,7 +15,10 @@ class ExecutionProviders; class KernelRegistryManager; class Model; struct ConfigOptions; -struct EpContextModelGenerationOptions; + +namespace epctx { +struct ModelGenOptions; +} class GraphPartitioner { public: @@ -50,7 +53,7 @@ class GraphPartitioner { const ConfigOptions& config_options, const logging::Logger& logger, Mode mode = Mode::kNormal, - const EpContextModelGenerationOptions& ep_context_gen_options = {}, + const epctx::ModelGenOptions& ep_context_gen_options = {}, const layout_transformation::DebugGraphFn& debug_graph_fn = {}) const; bool IsLoadCancellationFlagSet() const { diff --git a/onnxruntime/core/framework/session_options.cc b/onnxruntime/core/framework/session_options.cc index 231eb47603838..63f928d52d788 100644 --- a/onnxruntime/core/framework/session_options.cc +++ b/onnxruntime/core/framework/session_options.cc @@ -99,20 +99,11 @@ void SessionOptions::AddCustomOpLibraryHandle(PathString library_name, void* lib } #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) -EpContextModelGenerationOptions::EpContextModelGenerationOptions(const ConfigOptions& config_options) { - enable = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1"; - output_model_file_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); - output_external_initializers_file_path = config_options.GetConfigOrDefault( - kOrtSessionOptionsEpContextModelExternalInitializersFileName, ""); - output_external_initializer_size_threshold = 0; - embed_ep_context_in_model = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "0") == "1"; -} - -EpContextModelGenerationOptions SessionOptions::GetEpContextGenerationOptions() const { +epctx::ModelGenOptions SessionOptions::GetEpContextGenerationOptions() const { if (this->has_explicit_ep_context_gen_options) { return this->ep_context_gen_options; } - return 
EpContextModelGenerationOptions(this->config_options); + return epctx::ModelGenOptions(this->config_options); } } // namespace onnxruntime diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h index 89a43c4f71ee6..e44ac95c2c890 100644 --- a/onnxruntime/core/framework/session_options.h +++ b/onnxruntime/core/framework/session_options.h @@ -11,8 +11,8 @@ #include #include #include "core/common/inlined_containers.h" -#include "core/framework/allocator.h" #include "core/framework/config_options.h" +#include "core/framework/ep_context_options.h" #include "core/framework/ort_value.h" #include "core/session/onnxruntime_c_api.h" #include "core/optimizer/graph_transformer_level.h" @@ -70,26 +70,6 @@ struct FreeDimensionOverride { using CheckLoadCancellationFn = std::function; -struct EpContextModelGenerationOptions { - EpContextModelGenerationOptions() = default; - - // Initializes from string key/value pairs in session config options. - explicit EpContextModelGenerationOptions(const ConfigOptions& config_options); - - bool enable = false; - bool overwrite_existing_output_file = false; - bool error_if_no_compiled_nodes = false; - bool embed_ep_context_in_model = false; - - std::string output_model_file_path; - void** output_model_buffer_ptr = nullptr; - size_t* output_model_buffer_size_ptr = nullptr; - AllocatorPtr output_model_buffer_allocator = nullptr; - - std::string output_external_initializers_file_path; - size_t output_external_initializer_size_threshold = 0; -}; - struct EpSelectionPolicy { // flag to detect that a policy was set by the user. // need to preserve current behavior of defaulting to CPU EP if no EPs are explicitly registered @@ -239,12 +219,12 @@ struct SessionOptions { // Options for generating compile EPContext models were previously stored in session_option.configs as // string key/value pairs. 
To support more advanced options, such as setting input/output buffers, we - // now have to store EPContext options in a struct of type EpContextModelGenerationOptions. + // now have to store EPContext options in a struct of type epctx::ModelGenOptions. // The function GetEpContextGenerationOptions() handles conversion of string key/value pairs to the new // struct type. bool has_explicit_ep_context_gen_options = false; - EpContextModelGenerationOptions ep_context_gen_options = {}; - EpContextModelGenerationOptions GetEpContextGenerationOptions() const; + epctx::ModelGenOptions ep_context_gen_options = {}; + epctx::ModelGenOptions GetEpContextGenerationOptions() const; }; inline std::ostream& operator<<(std::ostream& os, const SessionOptions& session_options) { diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 94a2a6677358e..2efa03bd3adf1 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -179,11 +179,17 @@ Status ReadExternalDataForTensor(const ONNX_NAMESPACE::TensorProto& tensor_proto GetExternalDataInfo(tensor_proto, tensor_proto_dir, external_file_path, file_offset, tensor_byte_size)); unpacked_tensor.resize(tensor_byte_size); - ORT_RETURN_IF_ERROR(onnxruntime::Env::Default().ReadFileIntoBuffer( - external_file_path.c_str(), - file_offset, - tensor_byte_size, - gsl::make_span(reinterpret_cast(unpacked_tensor.data()), tensor_byte_size))); + if (external_file_path == onnxruntime::utils::kTensorProtoMemoryAddressTag) { + // the value in location is the memory address of the data + const void* ext_data_buf = reinterpret_cast(file_offset); + std::memcpy(unpacked_tensor.data(), ext_data_buf, tensor_byte_size); + } else { + ORT_RETURN_IF_ERROR(onnxruntime::Env::Default().ReadFileIntoBuffer( + external_file_path.c_str(), + file_offset, + tensor_byte_size, + gsl::make_span(reinterpret_cast(unpacked_tensor.data()), tensor_byte_size))); + } 
return Status::OK(); } diff --git a/onnxruntime/core/graph/model_editor_api_types.h b/onnxruntime/core/graph/model_editor_api_types.h index d72bd13093b61..2ee84958e4d8b 100644 --- a/onnxruntime/core/graph/model_editor_api_types.h +++ b/onnxruntime/core/graph/model_editor_api_types.h @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#pragma once #include "core/common/inlined_containers_fwd.h" #include "core/framework/ort_value.h" diff --git a/onnxruntime/core/session/compile_api.cc b/onnxruntime/core/session/compile_api.cc index ad128fee6cc3d..12f1745a99678 100644 --- a/onnxruntime/core/session/compile_api.cc +++ b/onnxruntime/core/session/compile_api.cc @@ -10,6 +10,7 @@ #include "core/common/common.h" #include "core/session/allocator_adapters.h" #include "core/framework/error_code_helper.h" +#include "core/graph/model_editor_api_types.h" #include "core/session/abi_session_options_impl.h" #include "core/session/inference_session.h" #include "core/session/model_compilation_options.h" @@ -106,6 +107,27 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetInputModelFromBuff API_IMPL_END } +ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetInputModel, + _In_ OrtModelCompilationOptions* ort_model_compile_options, + _In_ const OrtModel* input_model) { + API_IMPL_BEGIN +#if !defined(ORT_MINIMAL_BUILD) + auto model_compile_options = reinterpret_cast(ort_model_compile_options); + + if (input_model == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid input model: OrtModel pointer is null"); + } + + model_compile_options->SetInputOrtModel(*input_model); + return nullptr; +#else + ORT_UNUSED_PARAMETER(ort_model_compile_options); + ORT_UNUSED_PARAMETER(input_model); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Compile API is not supported in this build"); +#endif // !defined(ORT_MINIMAL_BUILD) + API_IMPL_END +} + 
ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelPath, _In_ OrtModelCompilationOptions* ort_model_compile_options, const ORTCHAR_T* output_model_path) { @@ -185,6 +207,28 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelBuffer, API_IMPL_END } +ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelOutStream, + _In_ OrtModelCompilationOptions* ort_model_compile_options, + _In_ OrtOutStreamWriteFunc write_stream_func, _In_ void* stream_state) { + API_IMPL_BEGIN +#if !defined(ORT_MINIMAL_BUILD) + auto model_compile_options = reinterpret_cast(ort_model_compile_options); + + if (write_stream_func == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "OrtOutStreamWriteFunc function for output model is null"); + } + + model_compile_options->SetOutputModelOutStream(write_stream_func, stream_state); + return nullptr; +#else + ORT_UNUSED_PARAMETER(ort_model_compile_options); + ORT_UNUSED_PARAMETER(write_stream_func); + ORT_UNUSED_PARAMETER(stream_state); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Compile API is not supported in this build"); +#endif // !defined(ORT_MINIMAL_BUILD) + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetEpContextEmbedMode, _In_ OrtModelCompilationOptions* ort_model_compile_options, bool embed_ep_context_in_model) { @@ -229,6 +273,8 @@ static constexpr OrtCompileApi ort_compile_api = { &OrtCompileAPI::ModelCompilationOptions_SetOutputModelBuffer, &OrtCompileAPI::ModelCompilationOptions_SetEpContextEmbedMode, &OrtCompileAPI::CompileModel, + &OrtCompileAPI::ModelCompilationOptions_SetInputModel, + &OrtCompileAPI::ModelCompilationOptions_SetOutputModelOutStream, }; // checks that we don't violate the rule that the functions must remain in the slots they were originally assigned diff --git a/onnxruntime/core/session/compile_api.h b/onnxruntime/core/session/compile_api.h index b8c5211526b9d..231bd9a82eab5 100644 --- 
a/onnxruntime/core/session/compile_api.h +++ b/onnxruntime/core/session/compile_api.h @@ -28,5 +28,9 @@ ORT_API_STATUS_IMPL(ModelCompilationOptions_SetOutputModelBuffer, _In_ OrtModelC ORT_API_STATUS_IMPL(ModelCompilationOptions_SetEpContextEmbedMode, _In_ OrtModelCompilationOptions* model_compile_options, bool embed_ep_context_in_model); ORT_API_STATUS_IMPL(CompileModel, _In_ const OrtEnv* env, _In_ const OrtModelCompilationOptions* model_options); +ORT_API_STATUS_IMPL(ModelCompilationOptions_SetInputModel, _In_ OrtModelCompilationOptions* model_compile_options, + _In_ const OrtModel* input_model); +ORT_API_STATUS_IMPL(ModelCompilationOptions_SetOutputModelOutStream, _In_ OrtModelCompilationOptions* model_compile_options, + _In_ OrtOutStreamWriteFunc write_stream_func, _In_ void* stream_state); } // namespace OrtCompileAPI diff --git a/onnxruntime/core/session/model_compilation_options.cc b/onnxruntime/core/session/model_compilation_options.cc index d0cb092f78843..eed87a69e2071 100644 --- a/onnxruntime/core/session/model_compilation_options.cc +++ b/onnxruntime/core/session/model_compilation_options.cc @@ -9,6 +9,7 @@ #include #include "core/framework/allocator.h" +#include "core/framework/ep_context_options.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/environment.h" @@ -29,27 +30,28 @@ ModelCompilationOptions::ModelCompilationOptions(const onnxruntime::Environment& } void ModelCompilationOptions::SetInputModelPath(const std::string& input_model_path) { - ResetInputModelSettings(); - input_model_path_ = input_model_path; + input_model_variant_ = input_model_path; } void ModelCompilationOptions::SetInputModelFromBuffer(const void* input_model_data, size_t input_model_data_size) { - ResetInputModelSettings(); - input_model_data_ = input_model_data; - input_model_data_size_ = input_model_data_size; + input_model_variant_ = gsl::span(reinterpret_cast(input_model_data), + input_model_data_size); +} + +void 
ModelCompilationOptions::SetInputOrtModel(const OrtModel& ort_model) { + input_model_variant_ = &ort_model; } Status ModelCompilationOptions::SetOutputModelPath(const std::string& output_model_path) { ORT_RETURN_IF_ERROR(ResetOutputModelSettings()); ConfigOptions& config_options = session_options_.value.config_options; - EpContextModelGenerationOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options; + epctx::ModelGenOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options; - ep_context_gen_options.output_model_file_path = output_model_path; + ep_context_gen_options.output_model_location = output_model_path; - if (ep_context_gen_options.output_model_file_path.size() <= ConfigOptions::kMaxValueLength) { - Status status = config_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, - ep_context_gen_options.output_model_file_path.c_str()); + if (output_model_path.size() <= ConfigOptions::kMaxValueLength) { + Status status = config_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, output_model_path.c_str()); ORT_ENFORCE(status.IsOK()); // Should not fail because both key/value strings are below the min string lengths // required by ConfigOptions::AddConfigEntry(). } else { @@ -70,7 +72,7 @@ Status ModelCompilationOptions::SetOutputModelPath(const std::string& output_mod logging::LoggingManager* log_manager = env_.GetLoggingManager(); if (log_manager != nullptr && log_manager->HasDefaultLogger()) { const logging::Logger& logger = log_manager->DefaultLogger(); - LOGS(logger, WARNING) << "Output model path length (" << ep_context_gen_options.output_model_file_path.size() + LOGS(logger, WARNING) << "Output model path length (" << output_model_path.size() << ") exceeds limit of " << ConfigOptions::kMaxKeyLength << " characters." 
<< "ORT will still generated the expected output file, but EPs will see an empty " << "output model path in SessionOption's ConfigOptions."; @@ -91,12 +93,21 @@ Status ModelCompilationOptions::SetOutputModelBuffer(onnxruntime::AllocatorPtr a size_t* output_model_buffer_size_ptr) { ORT_RETURN_IF_ERROR(ResetOutputModelSettings()); - session_options_.value.ep_context_gen_options.output_model_buffer_ptr = output_model_buffer_ptr; - session_options_.value.ep_context_gen_options.output_model_buffer_size_ptr = output_model_buffer_size_ptr; - session_options_.value.ep_context_gen_options.output_model_buffer_allocator = std::move(allocator); + session_options_.value.ep_context_gen_options.output_model_location = epctx::BufferHolder{ + output_model_buffer_ptr, + output_model_buffer_size_ptr, + std::move(allocator), + }; return Status::OK(); } +void ModelCompilationOptions::SetOutputModelOutStream(OrtOutStreamWriteFunc write_stream_func, void* stream_state) { + session_options_.value.ep_context_gen_options.output_model_location = epctx::OutStreamHolder{ + write_stream_func, + stream_state, + }; +} + Status ModelCompilationOptions::SetEpContextEmbedMode(bool embed_ep_context_in_model) { ORT_RETURN_IF_ERROR(session_options_.value.config_options.AddConfigEntry( kOrtSessionOptionEpContextEmbedMode, embed_ep_context_in_model ? 
"1" : "0")); @@ -108,98 +119,93 @@ const OrtSessionOptions& ModelCompilationOptions::GetSessionOptions() const { return session_options_; } -bool ModelCompilationOptions::InputModelComesFromFile() const { - return !input_model_path_.empty(); -} - -const std::string& ModelCompilationOptions::GetInputModelPath() const { - return input_model_path_; -} - -const void* ModelCompilationOptions::GetInputModelData() const { - return input_model_data_; +const std::string* ModelCompilationOptions::TryGetInputModelPath() const { + return std::get_if(&input_model_variant_); } -size_t ModelCompilationOptions::GetInputModelDataSize() const { - return input_model_data_size_; +const gsl::span* ModelCompilationOptions::TryGetInputModelBuffer() const { + return std::get_if>(&input_model_variant_); } -void ModelCompilationOptions::ResetInputModelSettings() { - input_model_path_.clear(); - input_model_data_ = nullptr; - input_model_data_size_ = 0; +const OrtModel* ModelCompilationOptions::TryGetInputOrtModel() const { + const gsl::not_null* ort_model_ptr_ptr = std::get_if>( + &input_model_variant_); + return (ort_model_ptr_ptr == nullptr) ? 
nullptr : ort_model_ptr_ptr->get(); } Status ModelCompilationOptions::ResetOutputModelSettings() { - EpContextModelGenerationOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options; - ep_context_gen_options.output_model_file_path.clear(); - ep_context_gen_options.output_model_buffer_ptr = nullptr; - ep_context_gen_options.output_model_buffer_size_ptr = nullptr; - ep_context_gen_options.output_model_buffer_allocator = nullptr; return session_options_.value.config_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, ""); } -Status ModelCompilationOptions::CheckInputModelSettings() const { - const bool comes_from_file = !input_model_path_.empty(); - const bool comes_from_memory = input_model_data_ != nullptr; +Status ModelCompilationOptions::Check() const { + ORT_ENFORCE(session_options_.value.ep_context_gen_options.enable); + ORT_ENFORCE(session_options_.value.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableModelCompile, "0") == "0"); - if (!comes_from_file && !comes_from_memory) { + // Check input model settings + if (std::holds_alternative(input_model_variant_)) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input model to compile must be loaded from either a file or a memory buffer"); + "Input model to compile must be loaded from either a file, a memory buffer, ", + "or an OrtModel instance"); } - if (comes_from_file && comes_from_memory) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input model to compile must be loaded from either a file or a memory buffer, ", - "but not both."); - } + const std::string* input_model_path_ptr = TryGetInputModelPath(); + const gsl::span* input_model_buffer_ptr = TryGetInputModelBuffer(); + const OrtModel* ort_model = TryGetInputOrtModel(); - if (comes_from_file && !std::filesystem::exists(input_model_path_)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input model path does not exist: ", input_model_path_); + if (input_model_path_ptr != nullptr && 
!std::filesystem::exists(*input_model_path_ptr)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input model path does not exist: ", *input_model_path_ptr); } - if (comes_from_memory && input_model_data_size_ == 0) { + if (input_model_buffer_ptr != nullptr && input_model_buffer_ptr->size() == 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Buffer for input model data has size 0"); } - return Status::OK(); -} + if (ort_model != nullptr && ort_model->graph->nodes.empty()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input OrtModel instance has no nodes"); + } -Status ModelCompilationOptions::CheckOutputModelSettings() const { - const EpContextModelGenerationOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options; + // Check output model settings + const epctx::ModelGenOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options; - const bool explicit_writes_to_file = !ep_context_gen_options.output_model_file_path.empty(); - const bool writes_to_buffer = ep_context_gen_options.output_model_buffer_ptr != nullptr; + const bool has_no_output_model_location = std::holds_alternative( + ep_context_gen_options.output_model_location); - if (!explicit_writes_to_file && !writes_to_buffer) { - // User did not specify an output file or an output buffer. We default to generating an output file + if (has_no_output_model_location && input_model_path_ptr != nullptr) { + // User did not specify an output file, output buffer, or output stream. We default to generating an output file // with a name based on the input file name, so do not return an error. 
return Status::OK(); } - if (explicit_writes_to_file && writes_to_buffer) { + if (has_no_output_model_location) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Output model to compile must be saved either to a file or to a buffer, but not both."); + "Unable to generate an output model path: require an input model path if the location " + "of the output model (e.g., file, buffer, or stream) is not specified."); } - if (writes_to_buffer && ep_context_gen_options.output_model_buffer_size_ptr == nullptr) { + const epctx::BufferHolder* output_buffer_ptr = ep_context_gen_options.TryGetOutputModelBuffer(); + + if (output_buffer_ptr != nullptr && output_buffer_ptr->buffer_ptr == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Invalid buffer configuration for output model: buffer pointer is null"); + } + + if (output_buffer_ptr != nullptr && output_buffer_ptr->buffer_size_ptr == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid buffer configuration for output model: size pointer is null"); } - if (writes_to_buffer && ep_context_gen_options.output_model_buffer_allocator == nullptr) { + if (output_buffer_ptr != nullptr && output_buffer_ptr->buffer_allocator == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid buffer configuration for output model: allocator is null"); } - return Status::OK(); -} + const epctx::OutStreamHolder* output_stream_ptr = ep_context_gen_options.TryGetOutputModelOutStream(); + + if (output_stream_ptr != nullptr && output_stream_ptr->write_func == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Invalid write-to-stream function for output model: function pointer is null"); + } -Status ModelCompilationOptions::Check() const { - ORT_ENFORCE(session_options_.value.ep_context_gen_options.enable); - ORT_ENFORCE(session_options_.value.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableModelCompile, "0") == "0"); - 
ORT_RETURN_IF_ERROR(CheckInputModelSettings()); - ORT_RETURN_IF_ERROR(CheckOutputModelSettings()); return Status::OK(); } } // namespace onnxruntime diff --git a/onnxruntime/core/session/model_compilation_options.h b/onnxruntime/core/session/model_compilation_options.h index 9238264003645..a220878b37cab 100644 --- a/onnxruntime/core/session/model_compilation_options.h +++ b/onnxruntime/core/session/model_compilation_options.h @@ -4,11 +4,14 @@ #if !defined(ORT_MINIMAL_BUILD) #pragma once +#include #include #include +#include #include "core/common/status.h" #include "core/common/path_string.h" #include "core/framework/allocator.h" +#include "core/graph/model_editor_api_types.h" #include "core/session/abi_session_options_impl.h" #include "core/session/onnxruntime_c_api.h" #include "core/session/onnxruntime_session_options_config_keys.h" @@ -44,6 +47,13 @@ class ModelCompilationOptions { /// The size in bytes of the input model's buffer void SetInputModelFromBuffer(const void* input_model_data, size_t input_model_data_size); + /// + /// Sets the input OrtModel instance. + /// Overrides any previous call to SetInput*() + /// + /// OrtModel instance + void SetInputOrtModel(const OrtModel& ort_model); + /// /// Sets the file path to store the output/compiled ONNX model. /// Overrides any previous call to SetOutputModelPath() or SetOutputModelBuffer(). @@ -72,6 +82,13 @@ class ModelCompilationOptions { Status SetOutputModelBuffer(onnxruntime::AllocatorPtr allocator, void** output_model_buffer_ptr, size_t* output_model_buffer_size_ptr); + /// + /// Sets an output stream (write function + state) used to write out the compiled model. + /// + /// Write function + /// The stream state + void SetOutputModelOutStream(OrtOutStreamWriteFunc write_stream_func, void* stream_state); + /// /// Enables or disables the embedding of EPContext binary data into the `ep_cache_context` attribute of EPContext /// nodes. Defaults to false (dumped to file). 
@@ -87,30 +104,25 @@ class ModelCompilationOptions { const OrtSessionOptions& GetSessionOptions() const; /// - /// Returns the file path to the input ONNX model. + /// Returns a pointer to the input model's path or nullptr if the input model + /// is not read from file. /// - /// input model's path - const std::string& GetInputModelPath() const; + /// input model's path or nullptr + const std::string* TryGetInputModelPath() const; /// - /// Returns true if the input model is read from a file. + /// Returns a pointer to the input model's bytes buffer or nullptr if the input model + /// is not read from a buffer. /// - /// true if input model comes from a file - bool InputModelComesFromFile() const; + /// input model's bytes buffer or nullptr + const gsl::span* TryGetInputModelBuffer() const; /// - /// Returns the buffer that contains the bytes for the input ONNX model. - /// Returns nullptr if the input model is not stored in a buffer. + /// Returns a pointer to the OrtModel instance for the input model or nullptr if + /// the input model is not stored as an OrtModel. /// - /// pointer to input model's buffer - const void* GetInputModelData() const; - - /// - /// Returns the size in bytes of the buffer that contains the input ONNX model. - /// Returns 0 if the input model is not stored in a buffer. - /// - /// input model buffer's size in bytes - size_t GetInputModelDataSize() const; + /// The OrtModel or nullptr + const OrtModel* TryGetInputOrtModel() const; /// /// Checks if the compilation options described by this object are valid. 
@@ -119,16 +131,16 @@ class ModelCompilationOptions { Status Check() const; private: - void ResetInputModelSettings(); Status ResetOutputModelSettings(); - Status CheckInputModelSettings() const; - Status CheckOutputModelSettings() const; const onnxruntime::Environment& env_; OrtSessionOptions session_options_; - std::string input_model_path_; - const void* input_model_data_ = nullptr; - size_t input_model_data_size_ = 0; + + std::variant, // input model in buffer + gsl::not_null> // input model created via OrtModelEditor + input_model_variant_{}; }; } // namespace onnxruntime #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/session/model_editor_c_api.cc b/onnxruntime/core/session/model_editor_c_api.cc index 7d5d45b7b531b..04b3a1cf87500 100644 --- a/onnxruntime/core/session/model_editor_c_api.cc +++ b/onnxruntime/core/session/model_editor_c_api.cc @@ -216,12 +216,7 @@ ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateSessionFromModel, _In_ const OrtEnv *out = nullptr; ORT_TRY { - sess = std::make_unique( - options == nullptr ? 
onnxruntime::SessionOptions() : options->value, - env->GetEnvironment()); - - ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Load(*model)); - + ORT_API_RETURN_IF_STATUS_NOT_OK(onnxruntime::CreateSessionFromModel(options, env->GetEnvironment(), model, sess)); ORT_API_RETURN_IF_ERROR(InitializeSession(options, *sess)); *out = reinterpret_cast(sess.release()); diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index 8ca4ef6af1f44..c550640e60c63 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -139,13 +139,11 @@ static OrtStatus* CreateSessionAndLoadModelImpl(_In_ const OrtSessionOptions* op // If ep.context_enable is set, then ep.context_file_path is expected, otherwise ORT don't know where to generate the _ctx.onnx file if (options && model_path == nullptr) { - EpContextModelGenerationOptions ep_ctx_gen_options = options->value.GetEpContextGenerationOptions(); + epctx::ModelGenOptions ep_ctx_gen_options = options->value.GetEpContextGenerationOptions(); // This is checked by the OrtCompileApi's CompileModel() function, but we check again here in case // the user used the older SessionOptions' configuration entries to generate a compiled model. - if (ep_ctx_gen_options.enable && - ep_ctx_gen_options.output_model_file_path.empty() && - ep_ctx_gen_options.output_model_buffer_ptr == nullptr) { + if (ep_ctx_gen_options.enable && !ep_ctx_gen_options.HasOutputModelLocation()) { return OrtApis::CreateStatus(ORT_FAIL, "Inference session was configured with EPContext model generation enabled but " "without a valid location (e.g., file or buffer) for the output model. " @@ -262,22 +260,48 @@ OrtStatus* InitializeSession(_In_ const OrtSessionOptions* options, namespace onnxruntime { #if !defined(ORT_MINIMAL_BUILD) +// Creates a session from an OrtModel (created by model editor API). 
+Status CreateSessionFromModel(const OrtSessionOptions* options, + const Environment& env, + const OrtModel* ort_model, + std::unique_ptr& sess) { + sess = std::make_unique( + options == nullptr ? onnxruntime::SessionOptions() : options->value, + env); + +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) + if (options && !options->custom_op_domains_.empty()) { + ORT_RETURN_IF_ERROR(sess->AddCustomOpDomains(options->custom_op_domains_)); + } +#endif + + ORT_RETURN_IF_ERROR(sess->Load(*ort_model)); + return Status::OK(); +} + Status CompileModel(const Environment& env, const ModelCompilationOptions& model_compile_options) { ORT_RETURN_IF_ERROR(model_compile_options.Check()); std::unique_ptr session; const OrtSessionOptions* session_options = &model_compile_options.GetSessionOptions(); - if (model_compile_options.InputModelComesFromFile()) { - PathString input_model_path = ToPathString(model_compile_options.GetInputModelPath()); + if (const std::string* model_path_ptr = model_compile_options.TryGetInputModelPath(); + model_path_ptr != nullptr) { + PathString input_model_path = ToPathString(*model_path_ptr); ORT_RETURN_IF_ERROR(ToStatus(CreateSessionAndLoadModelImpl(session_options, env, input_model_path.c_str(), nullptr, 0, session))); - } else { + } else if (const gsl::span* model_buffer_ptr = model_compile_options.TryGetInputModelBuffer(); + model_buffer_ptr != nullptr) { ORT_RETURN_IF_ERROR(ToStatus(CreateSessionAndLoadModelImpl(session_options, env, nullptr, - model_compile_options.GetInputModelData(), - model_compile_options.GetInputModelDataSize(), + model_buffer_ptr->data(), + model_buffer_ptr->size(), session))); + } else if (const OrtModel* ort_model = model_compile_options.TryGetInputOrtModel(); ort_model != nullptr) { + ORT_RETURN_IF_ERROR(CreateSessionFromModel(session_options, env, ort_model, session)); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Invalid input model passed to CompileModel(): ", + "expected a file path, a 
buffer, or an OrtModel, but received an unknown kind of model."); } ORT_RETURN_IF_ERROR(ToStatus(InitializeSession(session_options, *session))); diff --git a/onnxruntime/core/session/utils.h b/onnxruntime/core/session/utils.h index 5a5dcae9165ed..9c5f7e34a3d80 100644 --- a/onnxruntime/core/session/utils.h +++ b/onnxruntime/core/session/utils.h @@ -21,6 +21,7 @@ class ModelCompilationOptions; } // namespace onnxruntime #if !defined(ORT_MINIMAL_BUILD) +struct OrtModel; namespace onnxruntime { class Environment; class EpLibrary; @@ -44,6 +45,19 @@ OrtStatus* InitializeSession(_In_ const OrtSessionOptions* options, #if !defined(ORT_MINIMAL_BUILD) namespace onnxruntime { +/// +/// Creates a session from an OrtModel instance. +/// +/// Optional session options +/// Reference to the Environment +/// The OrtModel instance +/// The session to create and initialize +/// A Status indicating an error or success. +Status CreateSessionFromModel(const OrtSessionOptions* options, + const onnxruntime::Environment& env, + const OrtModel* ort_model, + std::unique_ptr& sess); + /// /// Compiles an ONNX model into a model with EPContext nodes. Each EPContext node represents a subgraph compiled for /// a specific execution provider. 
diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py index e7dc4294f3672..b2ed617a74e8b 100644 --- a/onnxruntime/python/onnxruntime_inference_collection.py +++ b/onnxruntime/python/onnxruntime_inference_collection.py @@ -9,7 +9,7 @@ import os import typing import warnings -from collections.abc import Sequence +from collections.abc import Callable, Sequence from typing import Any from onnxruntime.capi import _pybind_state as C @@ -726,6 +726,16 @@ def compile_to_bytes(self) -> bytes: """ return self._model_compiler.compile_to_bytes() + def compile_to_stream(self, write_function: Callable[[bytes], None]): + """ + Compiles the input model and writes the serialized ONNX bytes to a stream using the provided write function. + + Raises an 'InvalidArgument' exception if the compilation options are invalid. + + :param write_function: A callable that accepts a bytes buffer to write. + """ + self._model_compiler.compile_to_stream(write_function) + class IOBinding: """ diff --git a/onnxruntime/python/onnxruntime_pybind_model_compiler.cc b/onnxruntime/python/onnxruntime_pybind_model_compiler.cc index 8bb7ee2098caf..f75931bd6b170 100644 --- a/onnxruntime/python/onnxruntime_pybind_model_compiler.cc +++ b/onnxruntime/python/onnxruntime_pybind_model_compiler.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. +#if !defined(ORT_MINIMAL_BUILD) #include "python/onnxruntime_pybind_model_compiler.h" #include @@ -72,9 +73,45 @@ onnxruntime::Status PyModelCompiler::CompileToBytes(std::string& output_buffer) return Status::OK(); } +/** + * Calls the user's Python PyOutStreamWriteFunc function and converts the results to a form that can be used + * by ORT to write out a compiled ONNX model. + * + * @param stream_state Opaque state that holds a pointer to the user's Python function. 
+ * @param buffer The buffer to write out. Contains a portion of the compiled ONNX model's bytes. + * @param buffer_num_bytes The number of bytes to write out. + * + * @return nullptr OrtStatus* to indicate success. + */ +static OrtStatus* ORT_API_CALL PyOutStreamWriteFuncWrapper(void* stream_state, const void* buffer, + size_t buffer_num_bytes) { + PyOutStreamWriteFunc* py_write_func = reinterpret_cast(stream_state); + OrtStatus* status = nullptr; + + // Call the Python write function and convert any exceptions to a status. + ORT_TRY { + pybind11::bytes py_bytes(reinterpret_cast(buffer), buffer_num_bytes); + (*py_write_func)(py_bytes); + } + ORT_CATCH(const std::exception& e) { + ORT_HANDLE_EXCEPTION([&]() { + status = ToOrtStatus(ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, e.what())); + }); + } + + return status; +} + +onnxruntime::Status PyModelCompiler::CompileToOutStream(PyOutStreamWriteFunc& write_func) { + model_compile_options_.SetOutputModelOutStream(PyOutStreamWriteFuncWrapper, reinterpret_cast(&write_func)); + ORT_RETURN_IF_ERROR(onnxruntime::CompileModel(*env_, model_compile_options_)); + return Status::OK(); +} + PyModelCompiler::PyModelCompiler(std::shared_ptr env, const PySessionOptions& sess_options, PrivateConstructorTag) : env_(env), model_compile_options_(*env, sess_options) { } } // namespace python } // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/python/onnxruntime_pybind_model_compiler.h b/onnxruntime/python/onnxruntime_pybind_model_compiler.h index 6c9f48fa00ba6..9fb59a33164b1 100644 --- a/onnxruntime/python/onnxruntime_pybind_model_compiler.h +++ b/onnxruntime/python/onnxruntime_pybind_model_compiler.h @@ -3,7 +3,6 @@ // Licensed under the MIT License. 
#pragma once -#if !defined(ORT_MINIMAL_BUILD) #include #include #include "core/common/status.h" @@ -14,11 +13,19 @@ namespace onnxruntime { class Environment; namespace python { +// Type of the function provided by Python code that is called by ORT to write out the compiled model. +// Returns the number of bytes written to the underlying stream. +using PyOutStreamWriteFunc = std::function; + /// /// Class exposed to Python that enables compiling ONNX models. /// Internally wraps a onnxruntime::ModelCompilationOptions that stores and validates settings. /// class PyModelCompiler { +#if defined(ORT_MINIMAL_BUILD) + public: + bool not_defined_in_this_build{}; // Prevent empty class warning. +#else private: // private tag to pass to constructor to ensure that constructor cannot be directly called externally struct PrivateConstructorTag {}; @@ -68,11 +75,18 @@ class PyModelCompiler { /// A Status indicating error or success. onnxruntime::Status CompileToBytes(std::string& output_buffer); + /// + /// Compiles the input model and writes the result into the provided output stream (write functor). + /// + /// Write functor that encapsulates the stream's state. + /// A Status indicating error or success. 
+ onnxruntime::Status CompileToOutStream(PyOutStreamWriteFunc& write_func); + private: std::shared_ptr env_; onnxruntime::ModelCompilationOptions model_compile_options_; std::string input_model_bytes_; +#endif // !defined(ORT_MINIMAL_BUILD) }; } // namespace python } // namespace onnxruntime -#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index aa2c0cc6a0f86..9e19561749ca1 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -2832,7 +2832,19 @@ including arg name, arg type (contains both type and shape).)pbdoc") ORT_THROW("Compile API is not supported in this build."); #endif }, - R"pbdoc(Compile an ONNX model into a buffer.)pbdoc"); + R"pbdoc(Compile an ONNX model into a buffer.)pbdoc") + .def( + "compile_to_stream", + [](PyModelCompiler* model_compiler, PyOutStreamWriteFunc& py_stream_write_func) { +#if !defined(ORT_MINIMAL_BUILD) + OrtPybindThrowIfError(model_compiler->CompileToOutStream(py_stream_write_func)); +#else + ORT_UNUSED_PARAMETER(model_compiler); + ORT_UNUSED_PARAMETER(py_stream_write_func); + ORT_THROW("Compile API is not supported in this build."); +#endif + }, + R"pbdoc(Compile an ONNX model into an output stream using the provided write functor.)pbdoc"); } bool CreateInferencePybindStateModule(py::module& m) { diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index 76399743c97f8..3b90cfcce76db 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -11,6 +11,7 @@ #include "core/framework/kernel_registry.h" #include "core/framework/op_kernel.h" #include "core/framework/bfc_arena.h" +#include "core/framework/ep_context_options.h" #include "core/framework/session_state.h" #include "core/graph/graph_utils.h" #include "core/graph/graph_viewer.h" @@ -504,7 +505,7 @@ 
void LoadWithResourceAwarePartitioning(const ORTCHAR_T* model_path, ASSERT_STATUS_OK( partitioner.Partition(graph, session_state.GetMutableFuncMgr(), transform_layout_fn, sess_options.config_options, default_logger, GraphPartitioner::Mode::kNormal, - EpContextModelGenerationOptions{}, + epctx::ModelGenOptions{}, debug_graph_fn)); verifier_fn(graph); diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index 8d840b1a3d45f..180c1bae7d077 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include +#include #include #include "core/session/onnxruntime_cxx_api.h" @@ -25,6 +26,14 @@ namespace test { #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +// Returns QNN provider options that use the HTP backend (npu) and do not offload graph I/O qdq. +static ProviderOptions QnnHTPOptionsWithoutQDQOffloading() { + ProviderOptions provider_options; + provider_options["backend_type"] = "htp"; + provider_options["offload_graph_io_quantization"] = "0"; + return provider_options; +} + static int64_t GetNodeAttr(const Node& node, const std::string& attr_name, int64_t default_val) { const auto& attributes = node.GetAttributes(); if (auto entry = attributes.find(attr_name); entry != attributes.end()) { @@ -557,6 +566,117 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_OutputModelBuffer_Outpu allocator.Free(output_model_buffer); } +// Implementation of OrtOutStreamWriteFunc that writes the compiled model to a file. 
+static OrtStatus* ORT_API_CALL TestWriteToStream(void* stream_state, const void* buffer, size_t buffer_num_bytes) { + std::ofstream* outfile = reinterpret_cast(stream_state); + outfile->write(reinterpret_cast(buffer), buffer_num_bytes); + return nullptr; // No error +} + +// Implementation of OrtOutStreamWriteFunc that directly returns an OrtStatus indicating an error. +static OrtStatus* ORT_API_CALL ReturnStatusFromStream(void* stream_state, const void* buffer, size_t buffer_num_bytes) { + ORT_UNUSED_PARAMETER(stream_state); + ORT_UNUSED_PARAMETER(buffer); + ORT_UNUSED_PARAMETER(buffer_num_bytes); + return Ort::GetApi().CreateStatus(ORT_FAIL, "Error from OrtOutStreamWriteFunc callback"); +} + +// Test using the CompileModel() API with settings: +// - input OrtModel created via the model editor API +// - write output model to custom stream +TEST_F(QnnHTPBackendTests, CompileApi_InputOrtModel_OutputToStream) { + std::vector>> weights; // Model weights must remain valid through inference + // if we want to avoid a copy. + + // Create OrtModel with a Gemm. X input is 3x4, Y initializer is 4x8, Z output is 3x8. 
+ Ort::Graph graph; + std::vector graph_inputs; + std::vector graph_outputs; + + Ort::TensorTypeAndShapeInfo x_tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, + {3, 4}, nullptr); + auto x_type_info = Ort::TypeInfo::CreateTensorInfo(x_tensor_info.GetConst()); + graph_inputs.emplace_back("X", x_type_info.GetConst()); + + Ort::TensorTypeAndShapeInfo z_tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, + {3, 8}, nullptr); + auto z_type_info = Ort::TypeInfo::CreateTensorInfo(z_tensor_info.GetConst()); + graph_outputs.emplace_back("Z", z_type_info.GetConst()); + + graph.SetInputs(graph_inputs); + graph.SetOutputs(graph_outputs); + + std::vector attrs; + Ort::Node node("Gemm", onnxruntime::kOnnxDomain, "Gemm1", {"X", "Y"}, {"Z"}, attrs); + graph.AddNode(node); + + std::vector y_dims = {4, 8}; + weights.emplace_back(std::make_unique>(32)); + auto& y_values = *weights.back(); + std::iota(y_values.begin(), y_values.end(), 1.0f); + + auto mem_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + auto y_tensor = Ort::Value::CreateTensor(mem_info, y_values.data(), y_values.size(), y_dims.data(), y_dims.size()); + graph.AddInitializer("Y", y_tensor, /*data is external, avoid copy*/ true); + + std::vector opsets{{onnxruntime::kOnnxDomain, 18}, {onnxruntime::kMSDomain, 1}}; + Ort::Model model(opsets); + model.AddGraph(graph); + + // Initialize session options with QNN EP + Ort::SessionOptions so; + so.AppendExecutionProvider("QNN", QnnHTPOptionsWithoutQDQOffloading()); + + const ORTCHAR_T* output_model_file = ORT_TSTR("compileapi_ortmodel_ctx.onnx"); + std::filesystem::remove(output_model_file); + + // Open an output file. Test will incrementally write the output model to file + // via calls to our OrtOutStreamWriteFunc callback. 
+ ASSERT_FALSE(std::filesystem::exists(output_model_file)); + std::ofstream outfile(output_model_file, std::ios::binary); + + // Create model compilation options from the session options. + Ort::ModelCompilationOptions compile_options(*ort_env, so); + compile_options.SetInputModel(model.GetConst()); + compile_options.SetOutputModelOutStream(TestWriteToStream, reinterpret_cast(&outfile)); // Set output stream + compile_options.SetEpContextEmbedMode(true); + + // Compile the model. + Ort::Status status = Ort::CompileModel(*ort_env, compile_options); + ASSERT_TRUE(status.IsOK()) << status.GetErrorMessage(); + outfile.flush(); + outfile.close(); + + // Make sure the compiled model was generated (via our write stream function) + // and has the expected number of EPContext nodes. + ASSERT_TRUE(std::filesystem::exists(output_model_file)); + CheckEpContextNodeCounts(output_model_file, 1, 0); +} + +// Tests using an OrtOutStreamFunc function that returns an error. +TEST_F(QnnHTPBackendTests, CompileApi_OutputStream_ReturnStatus) { + // Create a test model (in memory). + TestModel test_model; + CreateTestModel(BuildGraphWithQAndNonQ(false), 21, logging::Severity::kERROR, test_model); + std::string model_data = test_model.Serialize(); + + // Initialize session options with QNN EP + Ort::SessionOptions so; + so.AppendExecutionProvider("QNN", QnnHTPOptionsWithoutQDQOffloading()); + + // Create model compilation options from the session options. + Ort::ModelCompilationOptions compile_options(*ort_env, so); + compile_options.SetInputModelFromBuffer(reinterpret_cast(model_data.data()), model_data.size()); + compile_options.SetOutputModelOutStream(ReturnStatusFromStream, nullptr); // Set output stream that returns error + compile_options.SetEpContextEmbedMode(true); + + // Compile the model. Expect a specific error status. 
+ Ort::Status status = Ort::CompileModel(*ort_env, compile_options); + ASSERT_FALSE(status.IsOK()); + EXPECT_EQ(status.GetErrorCode(), ORT_FAIL); + EXPECT_EQ(status.GetErrorMessage(), "Error from OrtOutStreamWriteFunc callback"); +} + // Test that models with 1 non-quantized FusedMatMul node and 1 quantized Add node can still generate the context binary // The generated Onnx model has 1 FusedMatMul node and 1 EPContext node TEST_F(QnnHTPBackendTests, QnnContextBinaryMultiPartitionSupport1) { diff --git a/onnxruntime/test/python/onnxruntime_test_python_compile_api.py b/onnxruntime/test/python/onnxruntime_test_python_compile_api.py index 7a410d4bbeb6a..d7012f1c9d5e8 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_compile_api.py +++ b/onnxruntime/test/python/onnxruntime_test_python_compile_api.py @@ -13,7 +13,7 @@ from helper import get_name import onnxruntime as onnxrt -from onnxruntime.capi.onnxruntime_pybind11_state import ModelRequiresCompilation +from onnxruntime.capi.onnxruntime_pybind11_state import Fail, ModelRequiresCompilation # handle change from python 3.8 and on where loading a dll from the current directory needs to be explicitly allowed. if platform.system() == "Windows" and sys.version_info.major >= 3 and sys.version_info.minor >= 8: # noqa: YTT204 @@ -176,6 +176,65 @@ def test_compile_from_buffer_to_buffer(self): self.assertTrue(isinstance(output_model_bytes, bytes)) self.assertGreater(len(output_model_bytes), 0) + def test_compile_from_file_to_stream(self): + """ + Tests compiling a model (from files) to an output stream using a custom write functor. 
+ """ + provider = None + provider_options = dict() + if "QNNExecutionProvider" in available_providers: + provider = "QNNExecutionProvider" + provider_options["backend_type"] = "htp" + + input_model_path = get_name("nhwc_resize_scales_opset18.onnx") + output_model_path = os.path.join(self._tmp_dir_path, "model.compiled.stream.onnx") + + with open(output_model_path, "wb") as output_fd: + # User's custom write functor. Writes the model to a file. + def my_write_func(buffer: bytes): + self.assertGreater(len(buffer), 0) + output_fd.write(buffer) + + session_options = onnxrt.SessionOptions() + if provider: + session_options.add_provider(provider, provider_options) + + model_compiler = onnxrt.ModelCompiler( + session_options, + input_model_path, + embed_compiled_data_into_model=True, + external_initializers_file_path=None, + ) + model_compiler.compile_to_stream(my_write_func) + + self.assertTrue(os.path.exists(output_model_path)) + + def test_compile_to_stream_that_raises_exception(self): + """ + Tests compiling a model to an output stream that always raises an exception. + """ + input_model_path = get_name("nhwc_resize_scales_opset18.onnx") + + # User's custom write functor that raises an exception. + test_py_error_message = "My Python Error" + + def my_write_func(buffer: bytes): + self.assertGreater(len(buffer), 0) + raise ValueError(test_py_error_message) + + session_options = onnxrt.SessionOptions() + model_compiler = onnxrt.ModelCompiler( + session_options, + input_model_path, + embed_compiled_data_into_model=True, + external_initializers_file_path=None, + ) + + # Try to compile and expect ORT to raise a Fail exception that contains our message. + with self.assertRaises(Fail) as context: + model_compiler.compile_to_stream(my_write_func) + self.assertIn(test_py_error_message, str(context.exception)) + def test_fail_load_uncompiled_model_and_then_compile(self): """ Tests compiling scenario: