Skip to content

Commit

Permalink
Validate the context_file_path before EP compile graphs (#23611)
Browse files Browse the repository at this point in the history
Validate the context_file_path before the EP compiles graphs so that it fails fast. This avoids the possibility that the EP generates a new file (context binary file or blob file) that overwrites an existing file. Return an error if the path points to a folder.
  • Loading branch information
HectorSVC authored and ashrit-ms committed Feb 11, 2025
1 parent 4e1f33c commit 7f06469
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ static const char* const kOrtSessionOptionEpContextEnable = "ep.context_enable";

// Specify the file path for the Onnx model which has EP context.
// Default to original_file_name_ctx.onnx if not specified
// A folder path is not a valid value.
static const char* const kOrtSessionOptionEpContextFilePath = "ep.context_file_path";

// Flag to specify whether to dump the EP context into the Onnx model.
Expand Down
36 changes: 30 additions & 6 deletions onnxruntime/core/framework/graph_partitioner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,29 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide
return Status::OK();
}

// Validates the EP context model output path so that an unusable path is rejected up front.
//
// ep_context_path: the configured output path; may be empty.
// model_path: path of the source model. When ep_context_path is empty, the output path is derived
//             from it by appending "_ctx.onnx".
//
// Returns:
//   - INVALID_ARGUMENT if ep_context_path refers to a folder, or if both paths are empty.
//   - FAIL if the resolved output path already exists, or if its existence cannot be determined.
//   - OK otherwise.
static Status EpContextFilePathCheck(const std::string& ep_context_path,
                                     const std::filesystem::path& model_path) {
  std::filesystem::path context_cache_path;
  if (!ep_context_path.empty()) {
    context_cache_path = ep_context_path;
    // A path ending in a directory separator (e.g. "dir/") has no filename component.
    if (!context_cache_path.has_filename()) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "context_file_path should not point to a folder.");
    }
  } else if (!model_path.empty()) {
    context_cache_path = model_path.native() + ORT_TSTR("_ctx.onnx");
  } else {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Both ep_context_path and model_path are empty.");
  }

  // Use the non-throwing overloads so that a filesystem error is reported through the returned
  // Status instead of escaping as an exception.
  std::error_code ec;
  const bool path_exists = std::filesystem::exists(context_cache_path, ec);
  if (ec) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to check whether the EP context model path '",
                           context_cache_path, "' exists: ", ec.message());
  }

  if (path_exists) {
    // An existing folder named without a trailing separator (e.g. "dir") passes the has_filename()
    // check above, so report it as a folder rather than as an already existing file.
    if (!ep_context_path.empty() && std::filesystem::is_directory(context_cache_path, ec)) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "context_file_path should not point to a folder.");
    }

    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to generate EP context model since the file '",
                           context_cache_path, "' exist already.");
  }

  return Status::OK();
}

static Status CreateEpContextModel(const ExecutionProviders& execution_providers,
const Graph& graph,
const std::filesystem::path& ep_context_path,
Expand Down Expand Up @@ -678,11 +701,6 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Both ep_context_path and model_path are empty");
}

if (std::filesystem::exists(context_cache_path)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to generate EP context model since the file '",
context_cache_path, "' exist already.");
}

Model ep_context_model(graph.Name(), false, graph.GetModel().MetaData(),
graph.GetModel().ModelPath(), // use source model path so that external initializers can find the data file path
IOnnxRuntimeOpSchemaRegistryList{graph.GetSchemaRegistry()},
Expand Down Expand Up @@ -1007,9 +1025,15 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr,

if (mode == Mode::kNormal || mode == Mode::kAssignOnly) {
#if !defined(ORT_MINIMAL_BUILD)
bool ep_context_enabled = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1";
if (ep_context_enabled) {
std::string ep_context_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, "");
// Check before EP compile graphs
ORT_RETURN_IF_ERROR(EpContextFilePathCheck(ep_context_path, graph.ModelPath()));
}

ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(partition_params, mode, providers_, kernel_registry_mgr_, logger));

bool ep_context_enabled = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1";
if (ep_context_enabled) {
std::string ep_context_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, "");
std::string external_ini_file_name = config_options.GetConfigOrDefault(kOrtSessionOptionsEpContextModelExternalInitializersFileName, "");
Expand Down
107 changes: 107 additions & 0 deletions onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,113 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCpuNodeWithoutExternalWeights) {
EpCtxCpuNodeWithExternalIniFileTestBody(false);
}

// ep.context_file_path must not be a folder path. Point it at a folder and verify that session
// creation is rejected with ORT_INVALID_ARGUMENT and the expected error message.
TEST_F(QnnHTPBackendTests, QnnContextBinaryGenerationFolderPathNotExpected) {
#if defined(_WIN32)
  const char* const htp_backend_lib = "QnnHtp.dll";
#else
  const char* const htp_backend_lib = "libQnnHtp.so";
#endif
  ProviderOptions qnn_options;
  qnn_options["backend_path"] = htp_backend_lib;
  qnn_options["offload_graph_io_quantization"] = "0";

  const std::unordered_map<std::string, int> opset_versions = {{"", 13}, {kMSDomain, 1}};

  auto& log_mgr = DefaultLoggingManager();
  log_mgr.SetDefaultLoggerSeverity(logging::Severity::kERROR);

  // Build the in-memory test model.
  onnxruntime::Model model("QNN_EP_TestModel", false, ModelMetaData(), PathString(),
                           IOnnxRuntimeOpSchemaRegistryList(), opset_versions, {},
                           log_mgr.DefaultLogger());
  Graph& graph = model.MainGraph();
  ModelTestBuilder builder(graph);
  const bool use_single_ep_node = true;
  BuildGraphWithQAndNonQ(use_single_ep_node)(builder);
  builder.SetGraphOutputs();
  ASSERT_STATUS_OK(graph.Resolve());

  // Keep the serialized bytes alive for as long as the span that views them is in use.
  std::string serialized_model;
  model.ToProto().SerializeToString(&serialized_model);
  const auto model_bytes = AsByteSpan(serialized_model.data(), serialized_model.size());

  // The trailing separator makes this a folder path rather than a file path.
  const std::string folder_path = "./ep_context_folder_not_expected/";
  std::remove(folder_path.c_str());

  Ort::SessionOptions session_options;
  session_options.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
  session_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, folder_path.c_str());
  session_options.AppendExecutionProvider("QNN", qnn_options);

  try {
    Ort::Session session(*ort_env, model_bytes.data(), model_bytes.size(), session_options);
    FAIL();  // Should not get here!
  } catch (const Ort::Exception& ex) {
    ASSERT_EQ(ex.GetOrtErrorCode(), ORT_INVALID_ARGUMENT);
    ASSERT_THAT(ex.what(), testing::HasSubstr("context_file_path should not point to a folder."));
  }
}

// Session 1 generates the EP context model and the context binary file.
// Session 2 repeats the same request and must fail because the output file exists already.
// The context binary's last-write time must be unchanged afterwards, i.e. session 2 did not
// overwrite it.
TEST_F(QnnHTPBackendTests, QnnContextBinaryGenerationNoOverWrite) {
  ProviderOptions provider_options;
#if defined(_WIN32)
  provider_options["backend_path"] = "QnnHtp.dll";
#else
  provider_options["backend_path"] = "libQnnHtp.so";
#endif
  provider_options["offload_graph_io_quantization"] = "0";

  const std::unordered_map<std::string, int> domain_to_version = {{"", 13}, {kMSDomain, 1}};

  auto& logging_manager = DefaultLoggingManager();
  logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR);

  onnxruntime::Model model("QNN_EP_TestModel", false, ModelMetaData(), PathString(),
                           IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {},
                           logging_manager.DefaultLogger());
  Graph& graph = model.MainGraph();
  ModelTestBuilder helper(graph);
  bool single_ep_node = true;
  BuildGraphWithQAndNonQ(single_ep_node)(helper);
  helper.SetGraphOutputs();
  ASSERT_STATUS_OK(model.MainGraph().Resolve());

  // Serialize the model to a string.
  std::string model_data;
  model.ToProto().SerializeToString(&model_data);

  const auto model_data_span = AsByteSpan(model_data.data(), model_data.size());

  const std::string ep_context_onnx_file = "./ep_context_no_over_write.onnx";
  const std::string ep_context_binary_file = "./ep_context_no_over_write.onnx_QNNExecutionProvider_QNN_10880527342279992768_1_0.bin";

  // Remove both artefacts a previous (possibly aborted) run may have left behind, so the
  // timestamp taken below belongs to the file written by session 1.
  std::remove(ep_context_onnx_file.c_str());
  std::remove(ep_context_binary_file.c_str());

  Ort::SessionOptions so;
  so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
  so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, ep_context_onnx_file.c_str());
  so.AppendExecutionProvider("QNN", provider_options);

  Ort::Session session1(*ort_env, model_data_span.data(), model_data_span.size(), so);

  auto modify_time_1 = std::filesystem::last_write_time(ep_context_binary_file);

  // Non-fatal checks are used inside the try/catch so that a failed expectation does not return
  // from the test body before the generated files are cleaned up below.
  try {
    Ort::Session session2(*ort_env, model_data_span.data(), model_data_span.size(), so);
    ADD_FAILURE() << "Session creation should have failed because the EP context model exists already.";
  } catch (const Ort::Exception& excpt) {
    EXPECT_EQ(excpt.GetOrtErrorCode(), ORT_FAIL);
    EXPECT_THAT(excpt.what(), testing::HasSubstr("exist already."));
    auto modify_time_2 = std::filesystem::last_write_time(ep_context_binary_file);
    EXPECT_EQ(modify_time_1, modify_time_2);
  }

  // Attempt both removals even if the first one fails.
  EXPECT_EQ(std::remove(ep_context_onnx_file.c_str()), 0);
  EXPECT_EQ(std::remove(ep_context_binary_file.c_str()), 0);
}

// Create a model with Cast + Add (quantized)
// cast_input -> Cast -> Q -> DQ \
// Add -> Q -> DQ -> output
Expand Down

0 comments on commit 7f06469

Please sign in to comment.