Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate the context_file_path before EP compile graphs #23611

Merged
merged 5 commits into from
Feb 8, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ static const char* const kOrtSessionOptionEpContextEnable = "ep.context_enable";

// Specify the file path for the Onnx model which has EP context.
// Default to original_file_name_ctx.onnx if not specified
// Folder is not a valid option
static const char* const kOrtSessionOptionEpContextFilePath = "ep.context_file_path";

// Flag to specify whether to dump the EP context into the Onnx model.
Expand Down
31 changes: 30 additions & 1 deletion onnxruntime/core/framework/graph_partitioner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -957,6 +957,29 @@ Status GraphPartitioner::InlineFunctionsAOT(Model& model,
return Status::OK();
}

// Validate the ep_context_path to make sure it is file path and check whether the file exist already
static Status EpContextFilePathCheck(const std::string& ep_context_path,
const std::filesystem::path& model_path) {
std::filesystem::path context_cache_path;
if (!ep_context_path.empty()) {
context_cache_path = ep_context_path;
if (!context_cache_path.has_filename()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "context_file_path should not point to a folder.");
}
} else if (!model_path.empty()) {
context_cache_path = model_path.native() + ORT_TSTR("_ctx.onnx");
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Both ep_context_path and model_path are empty.");
}

if (std::filesystem::exists(context_cache_path)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to generate EP context model since the file '",
context_cache_path, "' exist already.");
}

return Status::OK();
}

#endif

Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr,
Expand Down Expand Up @@ -1007,9 +1030,15 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr,

if (mode == Mode::kNormal || mode == Mode::kAssignOnly) {
#if !defined(ORT_MINIMAL_BUILD)
bool ep_context_enabled = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1";
if (ep_context_enabled) {
std::string ep_context_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, "");
// Check before EP compile graphs
ORT_RETURN_IF_ERROR(EpContextFilePathCheck(ep_context_path, graph.ModelPath()));
}

ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(partition_params, mode, providers_, kernel_registry_mgr_, logger));

bool ep_context_enabled = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1";
if (ep_context_enabled) {
std::string ep_context_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, "");
std::string external_ini_file_name = config_options.GetConfigOrDefault(kOrtSessionOptionsEpContextModelExternalInitializersFileName, "");
Expand Down
107 changes: 107 additions & 0 deletions onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,113 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCpuNodeWithoutExternalWeights) {
EpCtxCpuNodeWithExternalIniFileTestBody(false);
}

// Set ep.context_file_path to folder path which is not a valid option, check the error message
TEST_F(QnnHTPBackendTests, QnnContextBinaryGenerationFolderPathNotExpected) {
ProviderOptions provider_options;
#if defined(_WIN32)
provider_options["backend_path"] = "QnnHtp.dll";
#else
provider_options["backend_path"] = "libQnnHtp.so";
#endif
provider_options["offload_graph_io_quantization"] = "0";

const std::unordered_map<std::string, int> domain_to_version = {{"", 13}, {kMSDomain, 1}};

auto& logging_manager = DefaultLoggingManager();
logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR);

onnxruntime::Model model("QNN_EP_TestModel", false, ModelMetaData(), PathString(),
IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {},
logging_manager.DefaultLogger());
Graph& graph = model.MainGraph();
ModelTestBuilder helper(graph);
bool single_ep_node = true;
BuildGraphWithQAndNonQ(single_ep_node)(helper);
helper.SetGraphOutputs();
ASSERT_STATUS_OK(model.MainGraph().Resolve());

// Serialize the model to a string.
std::string model_data;
model.ToProto().SerializeToString(&model_data);

const auto model_data_span = AsByteSpan(model_data.data(), model_data.size());

const std::string ep_context_onnx_file = "./ep_context_folder_not_expected/";
std::remove(ep_context_onnx_file.c_str());
Ort::SessionOptions so;
so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, ep_context_onnx_file.c_str());
so.AppendExecutionProvider("QNN", provider_options);

try {
Ort::Session session(*ort_env, model_data_span.data(), model_data_span.size(), so);
FAIL(); // Should not get here!
} catch (const Ort::Exception& excpt) {
ASSERT_EQ(excpt.GetOrtErrorCode(), ORT_INVALID_ARGUMENT);
ASSERT_THAT(excpt.what(), testing::HasSubstr("context_file_path should not point to a folder."));
}
}

// Create session 1 to generate context binary file
// Create session 2 to do same thing, make sure session 2 failed because file exist already
// Make sure no new file over write from session 2
TEST_F(QnnHTPBackendTests, QnnContextBinaryGenerationNoOverWrite) {
ProviderOptions provider_options;
#if defined(_WIN32)
provider_options["backend_path"] = "QnnHtp.dll";
#else
provider_options["backend_path"] = "libQnnHtp.so";
#endif
provider_options["offload_graph_io_quantization"] = "0";

const std::unordered_map<std::string, int> domain_to_version = {{"", 13}, {kMSDomain, 1}};

auto& logging_manager = DefaultLoggingManager();
logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR);

onnxruntime::Model model("QNN_EP_TestModel", false, ModelMetaData(), PathString(),
IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {},
logging_manager.DefaultLogger());
Graph& graph = model.MainGraph();
ModelTestBuilder helper(graph);
bool single_ep_node = true;
BuildGraphWithQAndNonQ(single_ep_node)(helper);
helper.SetGraphOutputs();
ASSERT_STATUS_OK(model.MainGraph().Resolve());

// Serialize the model to a string.
std::string model_data;
model.ToProto().SerializeToString(&model_data);

const auto model_data_span = AsByteSpan(model_data.data(), model_data.size());

const std::string ep_context_onnx_file = "./ep_context_no_over_write.onnx";
const std::string ep_context_binary_file = "./ep_context_no_over_write.onnx_QNNExecutionProvider_QNN_10880527342279992768_1_0.bin";

std::remove(ep_context_onnx_file.c_str());
Ort::SessionOptions so;
so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, ep_context_onnx_file.c_str());
so.AppendExecutionProvider("QNN", provider_options);

Ort::Session session1(*ort_env, model_data_span.data(), model_data_span.size(), so);

auto modify_time_1 = std::filesystem::last_write_time(ep_context_binary_file);

try {
Ort::Session session2(*ort_env, model_data_span.data(), model_data_span.size(), so);
FAIL(); // Should not get here!
} catch (const Ort::Exception& excpt) {
ASSERT_EQ(excpt.GetOrtErrorCode(), ORT_FAIL);
ASSERT_THAT(excpt.what(), testing::HasSubstr("exist already."));
auto modify_time_2 = std::filesystem::last_write_time(ep_context_binary_file);
ASSERT_EQ(modify_time_1, modify_time_2);
}

ASSERT_EQ(std::remove(ep_context_onnx_file.c_str()), 0);
ASSERT_EQ(std::remove(ep_context_binary_file.c_str()), 0);
}

// Create a model with Case + Add (quantized)
// cast_input -> Cast -> Q -> DQ \
// Add -> Q -> DQ -> output
Expand Down
Loading