From 614af3742c3f1f3f2aa91a5699a6de1c9fe79fbb Mon Sep 17 00:00:00 2001 From: Brian Lambert <98757707+brian-pieces@users.noreply.github.com> Date: Tue, 26 Sep 2023 15:01:41 -0400 Subject: [PATCH] Add prepacked weights container to subgraphs (#17671) ### Description Adds prepacked weights container to model subgraphs. ### Motivation and Context Allows for initializer sharing when the initializers are located in subgraphs. I encountered this bug when attempting to share weights between T5 BeamSearch models where the shareable initializers are located in the encoder and decoder subgraphs and it failed to reduce memory usage. --- onnxruntime/core/framework/session_state.cc | 3 +- .../test/framework/session_state_test.cc | 124 ++++++++++++++++++ 2 files changed, 126 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index f0e5fbbd38721..6244d426450a2 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -1046,7 +1046,8 @@ Status SessionState::CreateSubgraphSessionState() { auto subgraph_session_state = std::make_unique(*subgraph, execution_providers_, thread_pool_, inter_op_thread_pool_, data_transfer_mgr_, - logger_, profiler_, sess_options_, nullptr, allocators_); + logger_, profiler_, sess_options_, + prepacked_weights_container_, allocators_); // Pass fused function manager to subgraph subgraph_session_state->fused_funcs_mgr_.SetFusedFuncs(fused_funcs_mgr_); diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index 82e5efd92a8f1..e1ce1d4abf81d 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -850,6 +850,130 @@ TEST_F(SessionStateTestSharedInitalizersWithPrePacking, test3) { ASSERT_EQ(session_state_2.GetUsedSharedPrePackedWeightCounter(), static_cast(1)); } +// Pre-packing enabled + shared initializers + +// pre-packed weights container + subgraphs = +// caching enabled in pre-packed weights used in subgraphs +TEST_F(SessionStateTestSharedInitalizersWithPrePacking, test4) { + SessionOptions sess_options; + sess_options.enable_mem_pattern = true; + sess_options.execution_mode = ExecutionMode::ORT_SEQUENTIAL; + sess_options.use_deterministic_compute = false; + sess_options.enable_mem_reuse = true; + // Enable pre-packing + sess_options.config_options.configurations[kOrtSessionOptionsConfigDisablePrepacking] = "0"; + + // Enable shared initializer + OrtMemoryInfo mem_info(CPU, OrtDeviceAllocator); + std::vector float_data(1, 1); + auto value = std::make_unique(); + Tensor::InitOrtValue(DataTypeImpl::GetType(), TensorShape(std::vector{1}), + reinterpret_cast(float_data.data()), mem_info, *value); + + ASSERT_STATUS_OK(sess_options.AddInitializer("if_shared", value.get())); + + // Enable pre-packed weights container + PrepackedWeightsContainer prepacked_weights_container; + + // First session/model + Model model_1("graph_main", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + domain_to_version, std::vector(), + DefaultLoggingManager().DefaultLogger()); + + CreateGraphWithSubgraph(model_1.MainGraph()); + PlaceAllNodesToCPUEP(model_1.MainGraph()); + SessionState session_state_1(model_1.MainGraph(), + execution_providers, + tp.get(), + nullptr, /*inter_op_thread_pool*/ + dtm, + DefaultLoggingManager().DefaultLogger(), + profiler, + sess_options, + &prepacked_weights_container); + + ASSERT_STATUS_OK(session_state_1.FinalizeSessionState(std::basic_string(), + kernel_registry_manager)); + + // At the main graph level, there should be no pre-packing calls as there are + // no initializers (shared or otherwise) consumed by any nodes in the main graph + ASSERT_EQ(session_state_1.GetNumberOfPrepacksCounter(), static_cast(0)); + + auto if_index_1 = 1; + if (session_state_1.GetKernel(0)->Node().OpType() == "If") { + if_index_1 = 0; + } + + const auto& subgraph_session_states = session_state_1.GetSubgraphSessionStateMap(); + const auto& if_node_session_states = subgraph_session_states.at(if_index_1); + const auto& session_state_1_then_branch_session_state = *if_node_session_states.at("then_branch"); + const auto& session_state_1_else_branch_session_state = *if_node_session_states.at("else_branch"); + + auto if_node_branches_prepack_counter_1 = + session_state_1_then_branch_session_state.GetNumberOfPrepacksCounter() + + session_state_1_else_branch_session_state.GetNumberOfPrepacksCounter(); + + // We should be seeing 2 pre-pack calls in the "If" node (one in each subgraph) + ASSERT_EQ(if_node_branches_prepack_counter_1, static_cast(2)); + + auto if_node_branches_shared_prepack_counter_1 = + session_state_1_then_branch_session_state.GetUsedSharedPrePackedWeightCounter() + + session_state_1_else_branch_session_state.GetUsedSharedPrePackedWeightCounter(); + + // We should only be seeing 1 shared pre-pack weights usage in the "If" node + // Either the "then branch" or "else branch" will be using the shared version + // depending on which branch writes to the shared container + ASSERT_EQ(if_node_branches_shared_prepack_counter_1, static_cast(1)); + + // Second session/model + Model model_2("graph_main", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + domain_to_version, std::vector(), + DefaultLoggingManager().DefaultLogger()); + + CreateGraphWithSubgraph(model_2.MainGraph()); + PlaceAllNodesToCPUEP(model_2.MainGraph()); + SessionState session_state_2(model_2.MainGraph(), + execution_providers, + tp.get(), + nullptr, /*inter_op_thread_pool*/ + dtm, + DefaultLoggingManager().DefaultLogger(), + profiler, + sess_options, + &prepacked_weights_container); + + ASSERT_STATUS_OK(session_state_2.FinalizeSessionState(std::basic_string(), + kernel_registry_manager)); + + // At the main graph level, there should be no pre-packing calls as there are + // no initializers (shared or otherwise) consumed by any nodes in the main graph + ASSERT_EQ(session_state_2.GetNumberOfPrepacksCounter(), static_cast(0)); + + auto if_index_2 = 1; + if (session_state_2.GetKernel(0)->Node().OpType() == "If") { + if_index_2 = 0; + } + + const auto& subgraph_session_states_2 = session_state_2.GetSubgraphSessionStateMap(); + const auto& if_node_session_states_2 = subgraph_session_states_2.at(if_index_2); + const auto& session_state_2_then_branch_session_state = *if_node_session_states_2.at("then_branch"); + const auto& session_state_2_else_branch_session_state = *if_node_session_states_2.at("else_branch"); + + auto if_node_branches_prepack_counter_2 = + session_state_2_then_branch_session_state.GetNumberOfPrepacksCounter() + + session_state_2_else_branch_session_state.GetNumberOfPrepacksCounter(); + + // We should be seeing 2 pre-pack calls in the "If" node (one in each subgraph) + ASSERT_EQ(if_node_branches_prepack_counter_2, static_cast(2)); + + auto if_node_branches_shared_prepack_counter_2 = + session_state_2_then_branch_session_state.GetUsedSharedPrePackedWeightCounter() + + session_state_2_else_branch_session_state.GetUsedSharedPrePackedWeightCounter(); + + // We should be seeing 2 shared pre-pack weights calls in the "If" node + // Both branches will be using the shared version coming from the first model. + ASSERT_EQ(if_node_branches_shared_prepack_counter_2, static_cast(2)); +} + INSTANTIATE_TEST_SUITE_P(SessionStateTests, SessionStatePrepackingTest, testing::Values(PrepackingTestParam{false, false},