diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index f0e5fbbd38721..6244d426450a2 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -1046,7 +1046,8 @@ Status SessionState::CreateSubgraphSessionState() { auto subgraph_session_state = std::make_unique(*subgraph, execution_providers_, thread_pool_, inter_op_thread_pool_, data_transfer_mgr_, - logger_, profiler_, sess_options_, nullptr, allocators_); + logger_, profiler_, sess_options_, + prepacked_weights_container_, allocators_); // Pass fused function manager to subgraph subgraph_session_state->fused_funcs_mgr_.SetFusedFuncs(fused_funcs_mgr_); diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index 82e5efd92a8f1..e1ce1d4abf81d 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -850,6 +850,130 @@ TEST_F(SessionStateTestSharedInitalizersWithPrePacking, test3) { ASSERT_EQ(session_state_2.GetUsedSharedPrePackedWeightCounter(), static_cast(1)); } +// Pre-packing enabled + shared initializers + +// pre-packed weights container + subgraphs = +// pre-packed weights consumed inside subgraphs are cached and shared across sessions +TEST_F(SessionStateTestSharedInitalizersWithPrePacking, test4) { + SessionOptions sess_options; + sess_options.enable_mem_pattern = true; + sess_options.execution_mode = ExecutionMode::ORT_SEQUENTIAL; + sess_options.use_deterministic_compute = false; + sess_options.enable_mem_reuse = true; + // Enable pre-packing (disable-prepacking config entry set to "0") + sess_options.config_options.configurations[kOrtSessionOptionsConfigDisablePrepacking] = "0"; + + // Enable shared initializer: a 1-element tensor registered below under the name "if_shared" + OrtMemoryInfo mem_info(CPU, OrtDeviceAllocator); + std::vector float_data(1, 1); + auto value = std::make_unique(); + Tensor::InitOrtValue(DataTypeImpl::GetType(), TensorShape(std::vector{1}), + reinterpret_cast(float_data.data()), mem_info, 
*value); + + ASSERT_STATUS_OK(sess_options.AddInitializer("if_shared", value.get())); + + // Enable pre-packed weights container (one instance, handed to both session states below) + PrepackedWeightsContainer prepacked_weights_container; + + // First session/model + Model model_1("graph_main", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + domain_to_version, std::vector(), + DefaultLoggingManager().DefaultLogger()); + + CreateGraphWithSubgraph(model_1.MainGraph()); + PlaceAllNodesToCPUEP(model_1.MainGraph()); + SessionState session_state_1(model_1.MainGraph(), + execution_providers, + tp.get(), + nullptr, /*inter_op_thread_pool*/ + dtm, + DefaultLoggingManager().DefaultLogger(), + profiler, + sess_options, + &prepacked_weights_container); + + ASSERT_STATUS_OK(session_state_1.FinalizeSessionState(std::basic_string(), + kernel_registry_manager)); + + // At the main graph level, there should be no pre-packing calls as there are + // no initializers (shared or otherwise) consumed by any nodes in the main graph + ASSERT_EQ(session_state_1.GetNumberOfPrepacksCounter(), static_cast(0)); + + auto if_index_1 = 1; /* assume the "If" kernel sits at index 1 unless it is found at index 0 */ + if (session_state_1.GetKernel(0)->Node().OpType() == "If") { + if_index_1 = 0; + } + + const auto& subgraph_session_states = session_state_1.GetSubgraphSessionStateMap(); + const auto& if_node_session_states = subgraph_session_states.at(if_index_1); + const auto& session_state_1_then_branch_session_state = *if_node_session_states.at("then_branch"); + const auto& session_state_1_else_branch_session_state = *if_node_session_states.at("else_branch"); + + auto if_node_branches_prepack_counter_1 = + session_state_1_then_branch_session_state.GetNumberOfPrepacksCounter() + + session_state_1_else_branch_session_state.GetNumberOfPrepacksCounter(); + + // We should be seeing 2 pre-pack calls in the "If" node (one in each subgraph) + ASSERT_EQ(if_node_branches_prepack_counter_1, static_cast(2)); + + auto if_node_branches_shared_prepack_counter_1 = + 
session_state_1_then_branch_session_state.GetUsedSharedPrePackedWeightCounter() + + session_state_1_else_branch_session_state.GetUsedSharedPrePackedWeightCounter(); + + // We should only be seeing 1 shared pre-pack weights usage in the "If" node + // Either the "then branch" or "else branch" will be using the shared version + // depending on which branch writes to the shared container + ASSERT_EQ(if_node_branches_shared_prepack_counter_1, static_cast(1)); + + // Second session/model (reuses the same session options and pre-packed weights container) + Model model_2("graph_main", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + domain_to_version, std::vector(), + DefaultLoggingManager().DefaultLogger()); + + CreateGraphWithSubgraph(model_2.MainGraph()); + PlaceAllNodesToCPUEP(model_2.MainGraph()); + SessionState session_state_2(model_2.MainGraph(), + execution_providers, + tp.get(), + nullptr, /*inter_op_thread_pool*/ + dtm, + DefaultLoggingManager().DefaultLogger(), + profiler, + sess_options, + &prepacked_weights_container); + + ASSERT_STATUS_OK(session_state_2.FinalizeSessionState(std::basic_string(), + kernel_registry_manager)); + + // At the main graph level, there should be no pre-packing calls as there are + // no initializers (shared or otherwise) consumed by any nodes in the main graph + ASSERT_EQ(session_state_2.GetNumberOfPrepacksCounter(), static_cast(0)); + + auto if_index_2 = 1; /* assume the "If" kernel sits at index 1 unless it is found at index 0 */ + if (session_state_2.GetKernel(0)->Node().OpType() == "If") { + if_index_2 = 0; + } + + const auto& subgraph_session_states_2 = session_state_2.GetSubgraphSessionStateMap(); + const auto& if_node_session_states_2 = subgraph_session_states_2.at(if_index_2); + const auto& session_state_2_then_branch_session_state = *if_node_session_states_2.at("then_branch"); + const auto& session_state_2_else_branch_session_state = *if_node_session_states_2.at("else_branch"); + + auto if_node_branches_prepack_counter_2 = + session_state_2_then_branch_session_state.GetNumberOfPrepacksCounter() + + 
session_state_2_else_branch_session_state.GetNumberOfPrepacksCounter(); + + // We should be seeing 2 pre-pack calls in the "If" node (one in each subgraph) + ASSERT_EQ(if_node_branches_prepack_counter_2, static_cast(2)); + + auto if_node_branches_shared_prepack_counter_2 = + session_state_2_then_branch_session_state.GetUsedSharedPrePackedWeightCounter() + + session_state_2_else_branch_session_state.GetUsedSharedPrePackedWeightCounter(); + + // We should be seeing 2 shared pre-pack weights calls in the "If" node + // Both branches will be using the shared version cached in the container by the first session. + ASSERT_EQ(if_node_branches_shared_prepack_counter_2, static_cast(2)); +} + INSTANTIATE_TEST_SUITE_P(SessionStateTests, SessionStatePrepackingTest, testing::Values(PrepackingTestParam{false, false},