Commit

resolve comments
centwang committed Nov 25, 2024
1 parent 8606668 commit 47d4755
Showing 17 changed files with 125 additions and 152 deletions.
47 changes: 23 additions & 24 deletions onnxruntime/core/optimizer/matmul_add_fusion.cc
@@ -15,6 +15,10 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();

// These two sets are used to skip the Attention pattern, which will be handled by AttentionFusion.
// The Attention pattern contains 4 MatMul-Add pairs: 3 of them follow a LayerNormalization node, and the
// remaining one produces an output that is added to the LayerNormalization's output. If we encounter nodes
// that are already stored in these two sets, we can skip them directly.
std::unordered_set<const Node*> attn_ln_nodes;
std::unordered_set<const Node*> attn_add_nodes;

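For orientation, a minimal standalone sketch (editor addition, not part of this commit; names are illustrative) of the heuristic the pass applies below when deciding whether a LayerNormalization node roots an Attention block: exactly three of its consumers are MatMul, exactly one is Add, and every remaining consumer is Shape.

#include <cstddef>

// Consumer counts collected from a LayerNormalization node's output edges.
struct ConsumerCounts {
  std::size_t matmul = 0;
  std::size_t add = 0;
  std::size_t shape = 0;
  std::size_t total = 0;  // total number of output edges
};

// Mirrors the check `add_count == 1 && matmul_count == 3 &&
// shape_count == parent_node->GetOutputEdgesCount() - 4` used by the fusion pass.
bool LooksLikeAttentionRoot(const ConsumerCounts& c) {
  return c.add == 1 && c.matmul == 3 && c.shape == c.total - 4;
}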
@@ -74,15 +78,16 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,

bool need_reshape = matmul_a_shape->dim_size() != 2;
const auto& dim_n = matmul_b_shape->dim(1);
std::vector<int64_t> shape_values;
InlinedVector<int64_t> shape_values;
int64_t m = 0, k = 0, n = 0;
if (need_reshape) {
// Skip Attention pattern, AttentionFusion will handle it. In such case, there are 4 MatMul-Add pairs,
// 3 of them are following LN, the other one produces output which is added by LN's output.
// 3 of them are following LN, the other one produces output which is added with LN's output.
const Node* parent_node = graph.GetProducerNode(matmul_input_defs[0]->Name());
if (attn_ln_nodes.count(parent_node) > 0 || attn_add_nodes.count(&next_node) > 0) {
continue;
}

if (parent_node && parent_node->OpType() == "LayerNormalization") {
unsigned int add_count = 0;
unsigned int matmul_count = 0;
@@ -99,6 +104,7 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
shape_count++;
}
}

if (add_count == 1 && matmul_count == 3 && shape_count == parent_node->GetOutputEdgesCount() - 4) {
size_t index = ln_add_node->InputDefs()[0]->Name() == parent_node->OutputDefs()[0]->Name() ? 1 : 0;
const Node* attn_add_node = graph.GetProducerNode(ln_add_node->InputDefs()[index]->Name());
@@ -112,39 +118,29 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,

// Logically we could use Shape-Concat to produce the shape input for Reshape. To keep it simple, we require
// both inputs to have concrete shapes for now; dynamic shape support can be added in the future.
bool is_concrete_shape = true;
for (int i = 0; i < matmul_a_shape->dim_size(); ++i) {
const auto& dim = matmul_a_shape->dim(i);
if (!utils::HasDimValue(dim)) {
is_concrete_shape = false;
break;
}
shape_values.emplace_back(dim.dim_value());
}
if (!is_concrete_shape) {
auto a_shape = utils::GetTensorShapeFromTensorShapeProto(*matmul_a_shape);
if (a_shape.Size() == -1) {
continue;
}

const auto& dim_k = matmul_b_shape->dim(0);
if (!utils::HasDimValue(dim_k) || !utils::HasDimValue(dim_n)) {
continue;
}

shape_values = a_shape.AsShapeVector();
// If a_shape is 1-D, m is 1 from SizeToDimension() with an empty dimension interval.
m = a_shape.SizeToDimension(a_shape.NumDimensions() - 1);
k = dim_k.dim_value();
n = dim_n.dim_value();
m = std::accumulate(shape_values.begin(), shape_values.end() - 1, static_cast<int64_t>(1),
std::multiplies<int64_t>());
}

const auto& matmul_output = *matmul_node.OutputDefs()[0];

auto matmul_output_name = matmul_output.Name();
auto gemm_input_defs = matmul_input_defs;
if (matmul_output_name == add_input_defs[0]->Name()) {
// matmul output as Add_A, should use Add_B as input C for gemm
gemm_input_defs.push_back(add_input_defs[1]);
} else {
// matmul output as Add_B, should use Add_A as input C for gemm
gemm_input_defs.push_back(add_input_defs[0]);
}
int bias_idx = matmul_output_name == add_input_defs[0]->Name() ? 1 : 0;
gemm_input_defs.push_back(add_input_defs[bias_idx]);

// valid bias_shapes are (N) or (1, N) or (M, 1) or (M, N) as
// GEMM only supports unidirectional broadcast on the bias input C
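To make the shape handling in this hunk concrete, here is a minimal standalone sketch (editor addition, not code from this commit; assumes static shapes, A of rank >= 1, and B of rank 2) of how the Gemm dimensions are derived and which bias shapes can be folded into the Gemm, per the unidirectional-broadcast comment above.

#include <cstdint>
#include <numeric>
#include <vector>

struct GemmDims {
  int64_t m, k, n;
};

// m collapses all leading dimensions of A (for a 1-D A the range is empty and m == 1,
// matching SizeToDimension()); k and n are read from the 2-D B input.
GemmDims DeriveGemmDims(const std::vector<int64_t>& a_shape, const std::vector<int64_t>& b_shape) {
  const int64_t m = std::accumulate(a_shape.begin(), a_shape.end() - 1,
                                    static_cast<int64_t>(1), std::multiplies<int64_t>());
  return GemmDims{m, b_shape[0], b_shape[1]};
}

// Gemm only broadcasts C unidirectionally, so only bias shapes (N), (1, N), (M, 1) or (M, N)
// are accepted by the fusion.
bool IsFusableBiasShape(const std::vector<int64_t>& bias_shape, int64_t m, int64_t n) {
  if (bias_shape.size() == 1) {
    return bias_shape[0] == n;
  }
  if (bias_shape.size() == 2) {
    return (bias_shape[0] == 1 && bias_shape[1] == n) ||
           (bias_shape[0] == m && bias_shape[1] == 1) ||
           (bias_shape[0] == m && bias_shape[1] == n);
  }
  return false;
}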
@@ -169,7 +165,7 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
Node* input_node = nullptr;
Node* output_node = nullptr;
if (need_reshape) {
auto add_reshape = [&](const std::vector<int64_t>& shape, Graph& graph, bool is_input) -> Node* {
auto add_reshape = [&](const InlinedVector<int64_t>& shape, Graph& graph, bool is_input) -> Node* {
const std::string name = is_input ? "gemm_input" : "gemm_output";
ONNX_NAMESPACE::TensorProto shape_initializer_proto;
shape_initializer_proto.set_name(graph.GenerateNodeName(name + "_shape"));
Expand Down Expand Up @@ -198,7 +194,7 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
gemm_output_defs[0] = output_node->MutableInputDefs()[0];
}

Node& gemm_node = graph.AddNode(graph.GenerateNodeName(matmul_node.Name() + "/MatMulAddFusion/"), "Gemm",
Node& gemm_node = graph.AddNode(graph.GenerateNodeName(matmul_node.Name() + "/MatMulAddFusion"), "Gemm",
"fused Matmul and Add", gemm_input_defs, gemm_output_defs);
gemm_node.SetExecutionProviderType(matmul_node.GetExecutionProviderType());

@@ -218,13 +214,16 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
graph.AddEdge(cur->src_node, gemm_node.Index(), cur->src_arg_index, 1);
}
}

graph_utils::GraphEdge::RemoveGraphEdges(graph, matmul_input_edges);
auto add_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(add_node);
for (auto cur = add_input_edges.cbegin(), end = add_input_edges.cend(); cur != end; ++cur) {
if (cur->dst_arg_index == 1) {
if (cur->dst_arg_index == bias_idx) {
graph.AddEdge(cur->src_node, gemm_node.Index(), cur->src_arg_index, 2);
break;
}
}

graph_utils::GraphEdge::RemoveGraphEdges(graph, add_input_edges);
graph_utils::RemoveNodeOutputEdges(graph, matmul_node);
graph_utils::ReplaceDownstreamNodeInput(graph, add_node, 0, *output_node, 0);
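When A is not 2-D, the fusion wraps the new Gemm in a pair of Reshape nodes: A is flattened to [m, k], the Gemm produces [m, n], and a second Reshape restores the original leading dimensions with n as the last axis. A small illustrative helper (editor addition, assumed behavior, not from this commit) for that restored output shape:

#include <cstdint>
#include <vector>

// e.g. a_shape = [batch, seq, k] with n gives [batch, seq, n] for the output Reshape.
std::vector<int64_t> RestoredOutputShape(std::vector<int64_t> a_shape, int64_t n) {
  a_shape.back() = n;  // keep the leading dimensions, replace K with N
  return a_shape;
}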
38 changes: 13 additions & 25 deletions onnxruntime/core/optimizer/reshape_fusion.cc
@@ -456,21 +456,20 @@ bool ReshapeFusion::Fuse_Subgraph(Node& reshape, Graph& graph, const logging::Lo

bool ReshapeFusion::FuseContiguousReshapes(Node& reshape, Graph& graph, const logging::Logger& logger) {
ORT_UNUSED_PARAMETER(logger);
InlinedVector<Node*> contiguous_reshapes{&reshape};
InlinedVector<std::reference_wrapper<Node>> contiguous_reshapes{reshape};
InlinedVector<int64_t> shape_value;
while (true) {
Node* p_curr_node = contiguous_reshapes.back();
if (graph.NodeProducesGraphOutput(*p_curr_node) || p_curr_node->GetOutputEdgesCount() != 1) {
Node& curr_node = contiguous_reshapes.back();
if (graph.NodeProducesGraphOutput(curr_node) || curr_node.GetOutputEdgesCount() != 1) {
break;
}

Node* p_next_node = graph.GetNode(p_curr_node->OutputNodesBegin()->Index());
if (p_next_node->OpType() != "Reshape" && p_next_node->OpType() != "Squeeze" &&
p_next_node->OpType() != "Unsqueeze") {
Node* next_node = graph.GetNode(curr_node.OutputNodesBegin()->Index());
if (next_node->OpType() != "Reshape" && next_node->OpType() != "Squeeze" && next_node->OpType() != "Unsqueeze") {
break;
}

auto shape = p_next_node->OutputDefs()[0]->Shape();
auto shape = next_node->OutputDefs()[0]->Shape();
if (!shape) {
break;
}
@@ -488,37 +487,26 @@ bool ReshapeFusion::FuseContiguousReshapes(Node& reshape, Graph& graph, const lo
break;
}

contiguous_reshapes.emplace_back(p_next_node);
contiguous_reshapes.emplace_back(*next_node);
}

if (contiguous_reshapes.size() < 2) {
return false;
}

const std::string& name = contiguous_reshapes[0]->Name();
const std::string& name = contiguous_reshapes[0].get().Name();
ONNX_NAMESPACE::TensorProto shape_initializer_proto;
shape_initializer_proto.set_name(graph.GenerateNodeName(name + "_new_shape"));
shape_initializer_proto.add_dims(static_cast<int64_t>(shape_value.size()));
shape_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
shape_initializer_proto.set_raw_data(shape_value.data(), shape_value.size() * sizeof(int64_t));
NodeArg* shape_arg = &graph_utils::AddInitializer(graph, shape_initializer_proto);
Node& reshape_node = graph.AddNode(graph.GenerateNodeName(name + "_new_reshape"), "Reshape", "Reshape for " + name,
{contiguous_reshapes[0]->MutableInputDefs()[0], shape_arg},
{contiguous_reshapes.back()->MutableOutputDefs()[0]});
reshape_node.SetExecutionProviderType(contiguous_reshapes[0]->GetExecutionProviderType());

auto input_edges = graph_utils::GraphEdge::GetNodeInputEdges(*contiguous_reshapes[0]);
for (auto cur = input_edges.cbegin(), end = input_edges.cend(); cur != end; ++cur) {
if (cur->dst_arg_index == 0) {
graph.AddEdge(cur->src_node, reshape_node.Index(), cur->src_arg_index, 0);
}
}
graph_utils::GraphEdge::RemoveGraphEdges(graph, input_edges);
graph_utils::ReplaceDownstreamNodeInput(graph, *contiguous_reshapes.back(), 0, reshape_node, 0);
for (Node* p_node : contiguous_reshapes) {
graph_utils::RemoveNodeOutputEdges(graph, *p_node);
graph.RemoveNode(p_node->Index());
}
{contiguous_reshapes[0].get().MutableInputDefs()[0], shape_arg},
{contiguous_reshapes.back().get().MutableOutputDefs()[0]});
reshape_node.SetExecutionProviderType(contiguous_reshapes[0].get().GetExecutionProviderType());

graph_utils::FinalizeNodeFusion(graph, contiguous_reshapes, reshape_node);

return true;
}
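Both FuseContiguousReshapes and MatMulAddFusion build the Reshape's shape input the same way: a 1-D INT64 initializer whose raw data is the target shape. A self-contained sketch of that step (editor addition; assumes the ONNX protobuf headers are available):

#include <cstdint>
#include <string>
#include <vector>

#include "onnx/onnx_pb.h"

// Builds a 1-D INT64 TensorProto holding `shape`, suitable as the second input of a Reshape node.
ONNX_NAMESPACE::TensorProto MakeShapeInitializer(const std::string& name,
                                                 const std::vector<int64_t>& shape) {
  ONNX_NAMESPACE::TensorProto proto;
  proto.set_name(name);
  proto.add_dims(static_cast<int64_t>(shape.size()));
  proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
  proto.set_raw_data(shape.data(), shape.size() * sizeof(int64_t));
  return proto;
}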
@@ -54,7 +54,7 @@ Status CastOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
std::vector<uint8_t> unpacked_tensor;
bool is_initializer_input = qnn_model_wrapper.IsInitializerInput(input_name);
if (is_initializer_input) {
const auto& input_tensor = qnn_model_wrapper.GetInitializerTensors().at(input_name);
const auto& input_tensor = qnn_model_wrapper.GetInitializerTensor(input_name);
ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_tensor, unpacked_tensor));
}

@@ -63,7 +63,7 @@ Status ExpandOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[1].node_arg, shape), "Cannot get shape");
uint32_t shape_rank = shape[0];
std::vector<uint8_t> unpacked_tensor;
const auto& input_tensor = qnn_model_wrapper.GetInitializerTensors().at(input_name);
const auto& input_tensor = qnn_model_wrapper.GetInitializerTensor(input_name);
ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_tensor, unpacked_tensor));
const int64_t* shape_data_int64 = reinterpret_cast<const int64_t*>(unpacked_tensor.data());
std::vector<uint32_t> input_shape(shape_rank, 0);
@@ -110,7 +110,7 @@ Status GemmOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
std::vector<uint8_t> unpacked_tensor;
bool is_initializer_input = qnn_model_wrapper.IsInitializerInput(input_name);
if (is_initializer_input) {
const auto& input_tensor = qnn_model_wrapper.GetInitializerTensors().at(input_name);
const auto& input_tensor = qnn_model_wrapper.GetInitializerTensor(input_name);
if (1 == input_trans_flag.at(input_i)) {
ORT_RETURN_IF_ERROR(quantize_param.HandleTranspose<size_t>(std::vector<size_t>({1, 0})));
ORT_RETURN_IF_ERROR(
@@ -187,7 +187,7 @@ Status PadOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrap
const auto& pads_input_name = inputs[1].node_arg.Name();

std::vector<uint8_t> unpacked_tensor;
const auto& input_tensor = qnn_model_wrapper.GetInitializerTensors().at(pads_input_name);
const auto& input_tensor = qnn_model_wrapper.GetInitializerTensor(pads_input_name);
ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_tensor, unpacked_tensor));
// ONNX Pads are int64, QNN uses uint32
const int64_t* tensor_data = reinterpret_cast<const int64_t*>(unpacked_tensor.data());
@@ -136,7 +136,7 @@ Status ReduceOpBuilder::GetAxesSet(QnnModelWrapper& qnn_model_wrapper, const Nod
}

// Get axes initializer bytes.
const auto& axes_tensor = qnn_model_wrapper.GetInitializerTensors().at(axes_input_name);
const auto& axes_tensor = qnn_model_wrapper.GetInitializerTensor(axes_input_name);
std::vector<uint8_t> axes_bytes;

ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*axes_tensor, axes_bytes));
@@ -88,7 +88,7 @@ Status SplitOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wr
bool is_initializer_input = qnn_model_wrapper.IsInitializerInput(input_name);
if (is_initializer_input) {
std::vector<uint8_t> unpacked_tensor;
const auto& input_tensor = qnn_model_wrapper.GetInitializerTensors().at(input_name);
const auto& input_tensor = qnn_model_wrapper.GetInitializerTensor(input_name);
ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_tensor, unpacked_tensor));
const int64_t* tensor_data = reinterpret_cast<const int64_t*>(unpacked_tensor.data());
size_t tensor_byte_size = unpacked_tensor.size();
@@ -69,7 +69,7 @@ Status TileOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra
const auto& repeats_input_name = node_unit.Inputs()[1].node_arg.Name();

std::vector<uint8_t> unpacked_tensor;
const auto& input_tensor = qnn_model_wrapper.GetInitializerTensors().at(repeats_input_name);
const auto& input_tensor = qnn_model_wrapper.GetInitializerTensor(repeats_input_name);
ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_tensor, unpacked_tensor));
// ONNX repeats are int64, QNN uses uint32
const int64_t* tensor_data = reinterpret_cast<const int64_t*>(unpacked_tensor.data());
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/qnn/builder/opbuilder/topk.cc
@@ -94,7 +94,7 @@ Status TopKOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra
bool is_initializer_input = qnn_model_wrapper.IsInitializerInput(input_name);
if (is_initializer_input) {
std::vector<uint8_t> unpacked_tensor;
const auto& input_tensor = qnn_model_wrapper.GetInitializerTensors().at(input_name);
const auto& input_tensor = qnn_model_wrapper.GetInitializerTensor(input_name);
ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_tensor, unpacked_tensor));
const int64_t* tensor_data = reinterpret_cast<const int64_t*>(unpacked_tensor.data());
k = static_cast<uint32_t>(*tensor_data);
@@ -489,7 +489,7 @@ Status QnnModelWrapper::GetTensorInfo(const NodeUnitIODef& input, TensorInfo& te
// Fill in initializer info.
tensor_info.is_initializer = IsInitializerInput(name);
if (tensor_info.is_initializer) {
tensor_info.initializer_tensor = GetInitializerTensors().at(name);
tensor_info.initializer_tensor = GetInitializerTensor(name);
}

return Status::OK();
4 changes: 4 additions & 0 deletions onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h
@@ -111,6 +111,10 @@ class QnnModelWrapper {

const InitializedTensorSet& GetInitializerTensors() const { return graph_viewer_.GetAllInitializedTensors(); }

const ONNX_NAMESPACE::TensorProto* GetInitializerTensor(const std::string& tensor_name) const {
return graph_viewer_.GetConstantInitializer(tensor_name, true);
}

bool IsInitializerInput(std::string input_name) const {
return initializer_lookup_.find(input_name) != initializer_lookup_.end();
}
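A hedged call-site sketch (editor addition; UnpackConstantInput and the null check are illustrative, the builders above dereference directly after IsInitializerInput) showing how the new helper replaces GetInitializerTensors().at(name), returning nullptr instead of throwing when the name is not a constant initializer:

#include <cstdint>
#include <string>
#include <vector>

#include "core/providers/qnn/builder/qnn_model_wrapper.h"

// Unpacks a constant-initializer input by name; fails gracefully if the name does not
// refer to a constant initializer (GetInitializerTensor returns nullptr in that case).
onnxruntime::Status UnpackConstantInput(onnxruntime::qnn::QnnModelWrapper& wrapper,
                                        const std::string& input_name,
                                        std::vector<uint8_t>& unpacked) {
  const ONNX_NAMESPACE::TensorProto* tensor = wrapper.GetInitializerTensor(input_name);
  ORT_RETURN_IF_NOT(tensor != nullptr, "Input is not a constant initializer: ", input_name);
  return wrapper.UnpackInitializerData(*tensor, unpacked);
}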