Skip to content

Commit

Permalink
Add prepacked weights container to subgraphs (microsoft#17671)
Browse files Browse the repository at this point in the history
### Description
Adds prepacked weights container to model subgraphs.



### Motivation and Context
Allows for initializer sharing when the initializers are located in
subgraphs. I encountered this bug when attempting to share weights
between T5 BeamSearch models: the shareable initializers are located
in the encoder and decoder subgraphs, so sharing them failed to reduce
memory usage.
  • Loading branch information
brian-at-pieces authored Sep 26, 2023
1 parent 0141e27 commit 614af37
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 1 deletion.
3 changes: 2 additions & 1 deletion onnxruntime/core/framework/session_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1046,7 +1046,8 @@ Status SessionState::CreateSubgraphSessionState() {
auto subgraph_session_state =
std::make_unique<SessionState>(*subgraph, execution_providers_,
thread_pool_, inter_op_thread_pool_, data_transfer_mgr_,
logger_, profiler_, sess_options_, nullptr, allocators_);
logger_, profiler_, sess_options_,
prepacked_weights_container_, allocators_);

// Pass fused function manager to subgraph
subgraph_session_state->fused_funcs_mgr_.SetFusedFuncs(fused_funcs_mgr_);
Expand Down
124 changes: 124 additions & 0 deletions onnxruntime/test/framework/session_state_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,130 @@ TEST_F(SessionStateTestSharedInitalizersWithPrePacking, test3) {
ASSERT_EQ(session_state_2.GetUsedSharedPrePackedWeightCounter(), static_cast<size_t>(1));
}

// Scenario under test:
//   pre-packing enabled + a shared initializer + a pre-packed weights container
//   + a graph that contains subgraphs.
// Expectation: pre-packed weights consumed inside the subgraphs are cached in,
// and served from, the pre-packed weights container.
TEST_F(SessionStateTestSharedInitalizersWithPrePacking, test4) {
  SessionOptions sess_options;
  sess_options.enable_mem_pattern = true;
  sess_options.enable_mem_reuse = true;
  sess_options.execution_mode = ExecutionMode::ORT_SEQUENTIAL;
  sess_options.use_deterministic_compute = false;
  // "0" for the disable-prepacking key keeps pre-packing switched on
  sess_options.config_options.configurations[kOrtSessionOptionsConfigDisablePrepacking] = "0";

  // Register a one-element float tensor as the shared initializer "if_shared".
  // The OrtValue wraps the vector's buffer, so both must outlive the sessions below.
  OrtMemoryInfo cpu_mem_info(CPU, OrtDeviceAllocator);
  std::vector<float> shared_data(1, 1);
  auto shared_value = std::make_unique<OrtValue>();
  Tensor::InitOrtValue(DataTypeImpl::GetType<float>(), TensorShape(std::vector<int64_t>{1}),
                       reinterpret_cast<void*>(shared_data.data()), cpu_mem_info, *shared_value);

  ASSERT_STATUS_OK(sess_options.AddInitializer("if_shared", shared_value.get()));

  // A single container instance is handed to both session states
  PrepackedWeightsContainer shared_container;

  // ---------------------------------------------------------------------
  // Session / model #1
  // ---------------------------------------------------------------------
  Model model_1("graph_main", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
                domain_to_version, std::vector<ONNX_NAMESPACE::FunctionProto>(),
                DefaultLoggingManager().DefaultLogger());

  CreateGraphWithSubgraph(model_1.MainGraph());
  PlaceAllNodesToCPUEP(model_1.MainGraph());
  SessionState session_state_1(model_1.MainGraph(),
                               execution_providers,
                               tp.get(),
                               nullptr, /*inter_op_thread_pool*/
                               dtm,
                               DefaultLoggingManager().DefaultLogger(),
                               profiler,
                               sess_options,
                               &shared_container);

  ASSERT_STATUS_OK(session_state_1.FinalizeSessionState(std::basic_string<PATH_CHAR_TYPE>(),
                                                        kernel_registry_manager));

  // No node in the main graph consumes an initializer (shared or otherwise),
  // so the main graph itself must not report any pre-packing calls
  ASSERT_EQ(session_state_1.GetNumberOfPrepacksCounter(), static_cast<size_t>(0));

  // The "If" node is looked up at index 0; if kernel 0 is not "If", index 1 is used
  const auto if_node_idx_1 = (session_state_1.GetKernel(0)->Node().OpType() == "If") ? 0 : 1;

  const auto& subgraph_states_1 = session_state_1.GetSubgraphSessionStateMap();
  const auto& if_states_1 = subgraph_states_1.at(if_node_idx_1);
  const auto& then_state_1 = *if_states_1.at("then_branch");
  const auto& else_state_1 = *if_states_1.at("else_branch");

  // One pre-pack call is expected in each branch subgraph => 2 in total
  const auto prepack_calls_1 =
      then_state_1.GetNumberOfPrepacksCounter() + else_state_1.GetNumberOfPrepacksCounter();
  ASSERT_EQ(prepack_calls_1, static_cast<size_t>(2));

  // Only one branch can reuse the shared pre-packed weight here: whichever branch
  // writes to the container first populates it, and the other branch then reads
  // the cached version
  const auto shared_prepack_uses_1 =
      then_state_1.GetUsedSharedPrePackedWeightCounter() +
      else_state_1.GetUsedSharedPrePackedWeightCounter();
  ASSERT_EQ(shared_prepack_uses_1, static_cast<size_t>(1));

  // ---------------------------------------------------------------------
  // Session / model #2 (same graph, same container)
  // ---------------------------------------------------------------------
  Model model_2("graph_main", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
                domain_to_version, std::vector<ONNX_NAMESPACE::FunctionProto>(),
                DefaultLoggingManager().DefaultLogger());

  CreateGraphWithSubgraph(model_2.MainGraph());
  PlaceAllNodesToCPUEP(model_2.MainGraph());
  SessionState session_state_2(model_2.MainGraph(),
                               execution_providers,
                               tp.get(),
                               nullptr, /*inter_op_thread_pool*/
                               dtm,
                               DefaultLoggingManager().DefaultLogger(),
                               profiler,
                               sess_options,
                               &shared_container);

  ASSERT_STATUS_OK(session_state_2.FinalizeSessionState(std::basic_string<PATH_CHAR_TYPE>(),
                                                        kernel_registry_manager));

  // As with the first session, the main graph must not report any pre-packing calls
  ASSERT_EQ(session_state_2.GetNumberOfPrepacksCounter(), static_cast<size_t>(0));

  // Same "If" node lookup as for the first session
  const auto if_node_idx_2 = (session_state_2.GetKernel(0)->Node().OpType() == "If") ? 0 : 1;

  const auto& subgraph_states_2 = session_state_2.GetSubgraphSessionStateMap();
  const auto& if_states_2 = subgraph_states_2.at(if_node_idx_2);
  const auto& then_state_2 = *if_states_2.at("then_branch");
  const auto& else_state_2 = *if_states_2.at("else_branch");

  // Again one pre-pack call per branch subgraph => 2 in total
  const auto prepack_calls_2 =
      then_state_2.GetNumberOfPrepacksCounter() + else_state_2.GetNumberOfPrepacksCounter();
  ASSERT_EQ(prepack_calls_2, static_cast<size_t>(2));

  // The container was already populated by the first model, so this time both
  // branches are expected to use the shared pre-packed weight
  const auto shared_prepack_uses_2 =
      then_state_2.GetUsedSharedPrePackedWeightCounter() +
      else_state_2.GetUsedSharedPrePackedWeightCounter();
  ASSERT_EQ(shared_prepack_uses_2, static_cast<size_t>(2));
}

INSTANTIATE_TEST_SUITE_P(SessionStateTests,
SessionStatePrepackingTest,
testing::Values(PrepackingTestParam{false, false},
Expand Down

0 comments on commit 614af37

Please sign in to comment.