[QNN] MatMulAddFusion and Reshape Related Fusion (#22494)
QNN EP relies on the Gemm op to use the QNN FullyConnected op to run the model,
which is much faster than MatMul+Add. This PR fuses MatMul+Add when MatMul's
2nd input is a 2D initializer, regardless of the rank of the 1st input. If the
1st input is not a 2D tensor, Reshape nodes are added around the fused Gemm.
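
For intuition, a minimal standalone sketch of the shape arithmetic behind the Reshape wrapping, assuming a hypothetical PlanMatMulAddFusion helper (not code from this PR): flatten A's leading dims into M, run Gemm on [M, K] x [K, N], then restore the leading dims with the last dim replaced by N.

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Hypothetical helper: A has shape a_shape (last dim == K), B has shape [K, N].
struct FusionShapes {
  std::vector<int64_t> gemm_input;    // [M, K]: A is reshaped to this
  std::vector<int64_t> gemm_output;   // [M, N]: Gemm output
  std::vector<int64_t> final_output;  // a_shape with the last dim replaced by N
};

FusionShapes PlanMatMulAddFusion(const std::vector<int64_t>& a_shape, int64_t k, int64_t n) {
  // M is the product of all leading dims; for a 1-D A it degenerates to 1.
  const int64_t m = std::accumulate(a_shape.begin(), a_shape.end() - 1, int64_t{1},
                                    std::multiplies<int64_t>());
  std::vector<int64_t> final_output = a_shape;
  final_output.back() = n;
  return {{m, k}, {m, n}, final_output};
}

int main() {
  // e.g. A: [2, 77, 512], B: [512, 768] -> Reshape A to [154, 512],
  // Gemm -> [154, 768], Reshape the result back to [2, 77, 768].
  const FusionShapes plan = PlanMatMulAddFusion({2, 77, 512}, 512, 768);
  for (int64_t d : plan.final_output) std::cout << d << ' ';  // prints: 2 77 768
  std::cout << '\n';
  return 0;
}
```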

On QNN EP, memory is allocated for each activation tensor, so
Reshape/Squeeze/Unsqueeze are not no-ops. This PR also adds a fusion that
removes redundant Reshape nodes. For some QNN AI Hub models on a specific
device, the graph cannot be finalized at execution time without removing these
Reshape nodes, but runs fine once they are removed.
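
Roughly, the new reshape fusion collapses a chain of shape-only ops into a single Reshape to the chain's final shape. A standalone sketch of that idea, using hypothetical types (ShapeOnlyOp, CollapseContiguousReshapes) rather than the ONNX Runtime graph API:

```cpp
#include <cstdint>
#include <iostream>
#include <optional>
#include <string>
#include <vector>

// Hypothetical, simplified view of a shape-only node: its op type and its
// concrete output shape (if shape inference produced one).
struct ShapeOnlyOp {
  std::string op_type;                            // "Reshape", "Squeeze" or "Unsqueeze"
  std::optional<std::vector<int64_t>> out_shape;  // fully-known output shape, if any
};

// Returns the single target shape that can replace the whole chain, or nullopt
// if the chain is too short or some output shape is not fully known.
std::optional<std::vector<int64_t>> CollapseContiguousReshapes(
    const std::vector<ShapeOnlyOp>& chain) {
  if (chain.size() < 2) return std::nullopt;
  for (const ShapeOnlyOp& op : chain) {
    const bool shape_only =
        op.op_type == "Reshape" || op.op_type == "Squeeze" || op.op_type == "Unsqueeze";
    if (!shape_only || !op.out_shape) return std::nullopt;
  }
  // None of these ops move data, so one Reshape to the last shape is equivalent.
  return chain.back().out_shape;
}

int main() {
  std::vector<ShapeOnlyOp> chain;
  chain.push_back({"Reshape", std::vector<int64_t>{1, 49, 768}});
  chain.push_back({"Squeeze", std::vector<int64_t>{49, 768}});
  chain.push_back({"Unsqueeze", std::vector<int64_t>{1, 49, 768, 1}});
  if (auto shape = CollapseContiguousReshapes(chain)) {
    for (int64_t d : *shape) std::cout << d << ' ';  // prints: 1 49 768 1
    std::cout << '\n';
  }
  return 0;
}
```

The actual pass additionally requires every node in the chain except the last to have a single consumer and to not produce a graph output.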

Ran the models below with and without the change (average inference time cost, with vs. without):
swin_tiny: 12.8077 ms vs. 23.956 ms
swin_base: 27.0639 ms vs. 57.6608 ms
convnext_tiny: 3.42956 ms vs. 16.1848 ms
openai_clip_CLIPTextEncoder: 5.96104 ms vs. 220.406 ms
openai_clip_CLIPImageEncoder: 41.8206 ms vs. 919.712 ms

NOTE that the current change skips the Attention pattern, because fusing
MatMul+Add there would prevent AttentionFusion from matching. Ideally
AttentionFusion should be adjusted to support the Gemm pattern, but that
requires big changes; we may do it in the future, e.g., when we want to run
transformer models on QNN. Since QNN has no Attention op, we would still want
to fuse MatMul+Add inside the Attention pattern so the QNN side can use
FullyConnected.

---------

Co-authored-by: adrianlizarraga <[email protected]>
centwang and adrianlizarraga authored Feb 18, 2025
1 parent 60d25b2 commit 03c6c2e
Showing 33 changed files with 943 additions and 393 deletions.
178 changes: 153 additions & 25 deletions onnxruntime/core/optimizer/matmul_add_fusion.cc
@@ -11,10 +11,63 @@ using namespace ONNX_NAMESPACE;
using namespace ::onnxruntime::common;
namespace onnxruntime {

namespace {

// The Attention subgraph has 4 MatMul-Add pairs that we want to skip here because AttentionFusion will handle them.
// In such a pattern, 3 of the MatMul-Add pairs follow a LayerNormalization, and the other one produces an output
// that is added to the LN's output. Use two sets to remember the patterns already seen during the graph iteration
// so that other MatMul-Add pairs in the same pattern can be skipped directly.
struct AttentionPatternCache {
bool IsAttentionPattern(const Graph& graph, const Node& matmul_node, const Node& add_node) {
const Node* parent_node = graph.GetProducerNode(matmul_node.InputDefs()[0]->Name());
if (attn_ln_nodes.count(parent_node) > 0 || attn_add_nodes.count(&add_node) > 0) {
return true;
}

if (parent_node && parent_node->OpType() == "LayerNormalization") {
unsigned int add_count = 0;
unsigned int matmul_count = 0;
unsigned int shape_count = 0;
const Node* ln_add_node = nullptr;
for (auto it = parent_node->OutputNodesBegin(); it != parent_node->OutputNodesEnd(); ++it) {
std::string op_type = (*it).OpType();
if (op_type == "Add") {
ln_add_node = &(*it);
add_count++;
} else if (op_type == "MatMul") {
matmul_count++;
} else if (op_type == "Shape") {
shape_count++;
}
}

if (add_count == 1 && matmul_count == 3 && shape_count == parent_node->GetOutputEdgesCount() - 4) {
size_t index = ln_add_node->InputDefs()[0]->Name() == parent_node->OutputDefs()[0]->Name() ? 1 : 0;
const Node* attn_add_node = graph.GetProducerNode(ln_add_node->InputDefs()[index]->Name());
if (attn_add_node && attn_add_node->OpType() == "Add") {
attn_ln_nodes.insert(parent_node);
attn_add_nodes.insert(attn_add_node);
return true;
}
}
}

return false;
}

std::unordered_set<const Node*> attn_ln_nodes;
std::unordered_set<const Node*> attn_add_nodes;
};

} // namespace

Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();

// Cache for skipping Attention subgraph pattern.
AttentionPatternCache attn_pattern_cache;

for (auto node_index : node_topology_list) {
auto* node_ptr = graph.GetNode(node_index);
if (!node_ptr)
@@ -65,58 +118,133 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
// Gemm only support Matrix, need to check the shape of MatMul and Add
auto matmul_a_shape = matmul_input_defs[0]->Shape();
auto matmul_b_shape = matmul_input_defs[1]->Shape();
if (nullptr == matmul_a_shape || nullptr == matmul_b_shape) {
if (nullptr == matmul_a_shape || nullptr == matmul_b_shape || matmul_b_shape->dim_size() != 2) {
continue;
}

if (2 != matmul_a_shape->dim_size() || 2 != matmul_b_shape->dim_size()) {
// Gemm only support Matrix
continue;
bool need_reshape = matmul_a_shape->dim_size() != 2;
const auto& dim_n = matmul_b_shape->dim(1);
InlinedVector<int64_t> shape_values;
int64_t m = 0, k = 0, n = 0;
if (need_reshape) {
// Only check and skip Attention pattern here because normally input to Attention is 4D.
if (attn_pattern_cache.IsAttentionPattern(graph, matmul_node, add_node)) {
continue;
}

// Logically we could use Shape+Concat to produce the shape input for Reshape. To keep it simple, we require
// both inputs to have concrete shapes for now; dynamic shape support can be added in the future.
auto a_shape = utils::GetTensorShapeFromTensorShapeProto(*matmul_a_shape);
if (a_shape.Size() == -1) {
continue;
}

const auto& dim_k = matmul_b_shape->dim(0);
if (!utils::HasDimValue(dim_k) || !utils::HasDimValue(dim_n)) {
continue;
}

shape_values = a_shape.AsShapeVector();
// If a_shape is 1D, m is 1 from SizeToDimension() with empty dimension interval.
m = a_shape.SizeToDimension(a_shape.NumDimensions() - 1);
k = dim_k.dim_value();
n = dim_n.dim_value();
}

const auto& matmul_output = *matmul_node.OutputDefs()[0];

auto matmul_output_name = matmul_output.Name();
auto gemm_input_defs = matmul_input_defs;
if (matmul_output_name == add_input_defs[0]->Name()) {
// matmul output as Add_A, should use Add_B as input C for gemm
gemm_input_defs.push_back(add_input_defs[1]);
} else {
// matmul output as Add_B, should use Add_A as input C for gemm
gemm_input_defs.push_back(add_input_defs[0]);
}
int bias_idx = matmul_output_name == add_input_defs[0]->Name() ? 1 : 0;
gemm_input_defs.push_back(add_input_defs[bias_idx]);

// valid bias_shapes are (N) or (1, N) or (M, 1) or (M, N) as
// GEMM only supports unidirectional broadcast on the bias input C
if (!gemm_input_defs.back()->Shape()) {
continue;
}
const auto& bias_shape = *gemm_input_defs.back()->Shape();
const auto& M = matmul_output.Shape()->dim()[0];
const auto& N = matmul_output.Shape()->dim()[1];
auto dim_has_value_1 = [](const TensorShapeProto_Dimension& dim) {
return dim.has_dim_value() && dim.dim_value() == 1;
};

bool valid = ((bias_shape.dim_size() == 1 && bias_shape.dim()[0] == N) ||
(bias_shape.dim_size() == 2 && dim_has_value_1(bias_shape.dim()[0]) && bias_shape.dim()[1] == N) ||
(bias_shape.dim_size() == 2 && bias_shape.dim()[0] == M &&
(dim_has_value_1(bias_shape.dim()[1]) || bias_shape.dim()[1] == N)));
bool valid = ((bias_shape.dim_size() == 1 && bias_shape.dim(0) == dim_n) ||
(!need_reshape && bias_shape.dim_size() == 2 && dim_has_value_1(bias_shape.dim(0)) &&
bias_shape.dim(1) == dim_n) ||
(!need_reshape && bias_shape.dim_size() == 2 && bias_shape.dim(0) == matmul_a_shape->dim(0) &&
(dim_has_value_1(bias_shape.dim(1)) || bias_shape.dim(1) == dim_n)));
if (!valid) {
continue;
}

Node& gemm_node = graph.AddNode(graph.GenerateNodeName(matmul_node.Name() + "/MatMulAddFusion/"),
"Gemm",
"fused Matmul and Add " + add_node.OpType(),
gemm_input_defs,
{});
auto gemm_output_defs = add_node.MutableOutputDefs();
Node* input_node = nullptr;
Node* output_node = nullptr;
if (need_reshape) {
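// Flatten A to [M, K] before the Gemm, and Reshape the Gemm output back to A's leading dims with the last dim set to N.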
auto add_reshape = [&](const InlinedVector<int64_t>& shape, Graph& graph, bool is_input) -> Node* {
const std::string name = is_input ? "gemm_input" : "gemm_output";
ONNX_NAMESPACE::TensorProto shape_initializer_proto;
shape_initializer_proto.set_name(graph.GenerateNodeName(name + "_shape"));
shape_initializer_proto.add_dims(static_cast<int64_t>(shape.size()));
shape_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
shape_initializer_proto.set_raw_data(shape.data(), shape.size() * sizeof(int64_t));
NodeArg* shape_arg = &graph_utils::AddInitializer(graph, shape_initializer_proto);
ONNX_NAMESPACE::TypeProto new_arg_type;
const ONNX_NAMESPACE::TensorProto_DataType element_type = static_cast<ONNX_NAMESPACE::TensorProto_DataType>(
gemm_input_defs[0]->TypeAsProto()->tensor_type().elem_type());
new_arg_type.mutable_tensor_type()->set_elem_type(element_type);
new_arg_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(m);
new_arg_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(is_input ? k : n);
NodeArg* new_arg = &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(name + "_reshape_arg"), &new_arg_type);
Node& reshape_node = graph.AddNode(graph.GenerateNodeName(name + "_reshape"), "Reshape", "Reshape for " + name,
{is_input ? gemm_input_defs[0] : new_arg, shape_arg},
{is_input ? new_arg : gemm_output_defs[0]});
reshape_node.SetExecutionProviderType(matmul_node.GetExecutionProviderType());
return &reshape_node;
};

input_node = add_reshape({m, k}, graph, true);
gemm_input_defs[0] = input_node->MutableOutputDefs()[0];
shape_values.back() = n;
output_node = add_reshape(shape_values, graph, false);
gemm_output_defs[0] = output_node->MutableInputDefs()[0];
}

// Assign provider to this new node. Provider should be same as the provider for old node.
Node& gemm_node = graph.AddNode(graph.GenerateNodeName(matmul_node.Name() + "/MatMulAddFusion"), "Gemm",
"fused Matmul and Add", gemm_input_defs, gemm_output_defs);
gemm_node.SetExecutionProviderType(matmul_node.GetExecutionProviderType());

// move output definitions and edges from act_node to gemm_node. delete gemm_node and act_node.
graph_utils::FinalizeNodeFusion(graph, {matmul_node, add_node}, gemm_node);
if (need_reshape) {
graph.AddEdge(input_node->Index(), gemm_node.Index(), 0, 0);
graph.AddEdge(gemm_node.Index(), output_node->Index(), 0, 0);
} else {
input_node = &gemm_node;
output_node = &gemm_node;
}

auto matmul_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(matmul_node);
for (auto cur = matmul_input_edges.cbegin(), end = matmul_input_edges.cend(); cur != end; ++cur) {
if (cur->dst_arg_index == 0) {
graph.AddEdge(cur->src_node, input_node->Index(), cur->src_arg_index, 0);
} else if (cur->dst_arg_index == 1) {
graph.AddEdge(cur->src_node, gemm_node.Index(), cur->src_arg_index, 1);
}
}

graph_utils::GraphEdge::RemoveGraphEdges(graph, matmul_input_edges);
auto add_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(add_node);
for (auto cur = add_input_edges.cbegin(), end = add_input_edges.cend(); cur != end; ++cur) {
if (cur->dst_arg_index == bias_idx) {
graph.AddEdge(cur->src_node, gemm_node.Index(), cur->src_arg_index, 2);
break;
}
}

graph_utils::GraphEdge::RemoveGraphEdges(graph, add_input_edges);
graph_utils::RemoveNodeOutputEdges(graph, matmul_node);
graph_utils::ReplaceDownstreamNodeInput(graph, add_node, 0, *output_node, 0);
graph.RemoveNode(matmul_node.Index());
graph.RemoveNode(add_node.Index());

modified = true;
}
51 changes: 51 additions & 0 deletions onnxruntime/core/optimizer/reshape_fusion.cc
@@ -48,6 +48,8 @@ Status ReshapeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, c
fused_count++;
LOGS(logger, INFO) << "Fused reshape node: " << reshape.OutputDefs()[0]->Name();
modified = true;
} else if (ReshapeFusion::FuseContiguousReshapes(reshape, graph)) {
modified = true;
}
}

@@ -452,4 +454,53 @@ bool ReshapeFusion::Fuse_Subgraph(Node& reshape, Graph& graph, const logging::Lo
return true;
}

bool ReshapeFusion::FuseContiguousReshapes(Node& reshape, Graph& graph) {
InlinedVector<std::reference_wrapper<Node>> contiguous_reshapes{reshape};
InlinedVector<int64_t> shape_value;
while (true) {
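// Walk downstream while the current node does not produce a graph output, has exactly one consumer,
// and that consumer is a Reshape/Squeeze/Unsqueeze with a fully known output shape.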
Node& curr_node = contiguous_reshapes.back();
if (graph.NodeProducesGraphOutput(curr_node) || curr_node.GetOutputEdgesCount() != 1) {
break;
}

Node* next_node = graph.GetNode(curr_node.OutputNodesBegin()->Index());
if (next_node->OpType() != "Reshape" && next_node->OpType() != "Squeeze" && next_node->OpType() != "Unsqueeze") {
break;
}

auto shape = next_node->OutputDefs()[0]->Shape();
if (!shape) {
break;
}

auto tensor_shape = utils::GetTensorShapeFromTensorShapeProto(*shape);
if (tensor_shape.Size() == -1) {
break;
}

shape_value = tensor_shape.AsShapeVector();
contiguous_reshapes.emplace_back(*next_node);
}

if (contiguous_reshapes.size() < 2) {
return false;
}

const std::string& name = contiguous_reshapes[0].get().Name();
ONNX_NAMESPACE::TensorProto shape_initializer_proto;
shape_initializer_proto.set_name(graph.GenerateNodeName(name + "_new_shape"));
shape_initializer_proto.add_dims(static_cast<int64_t>(shape_value.size()));
shape_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
shape_initializer_proto.set_raw_data(shape_value.data(), shape_value.size() * sizeof(int64_t));
NodeArg* shape_arg = &graph_utils::AddInitializer(graph, shape_initializer_proto);
Node& reshape_node = graph.AddNode(graph.GenerateNodeName(name + "_new_reshape"), "Reshape", "Reshape for " + name,
{contiguous_reshapes[0].get().MutableInputDefs()[0], shape_arg},
{contiguous_reshapes.back().get().MutableOutputDefs()[0]});
reshape_node.SetExecutionProviderType(contiguous_reshapes[0].get().GetExecutionProviderType());

graph_utils::FinalizeNodeFusion(graph, contiguous_reshapes, reshape_node);

return true;
}

} // namespace onnxruntime
5 changes: 5 additions & 0 deletions onnxruntime/core/optimizer/reshape_fusion.h
@@ -27,6 +27,11 @@ class ReshapeFusion : public GraphTransformer {
static bool Is_One_Element_Input(const Node& cur_node, int index);
static bool Is_One_Element_Output_Subgraph(Graph& graph, const NodeArg& root_input, const Node& concat,
int index, gsl::span<const int64_t> shape_value, const logging::Logger& logger);

// Remove contiguous Reshape/Squeeze/Unsqueeze nodes if the shape info is concrete.
// For some EPs, such as QNN EP, such reshape ops are not no-ops because memory is allocated for each output,
// so this fusion helps reduce memory usage on such devices.
static bool FuseContiguousReshapes(Node& reshape, Graph& graph);
};

} // namespace onnxruntime