Skip to content

Commit

Permalink
Some Shape Related Fusions (microsoft#19832)
Browse files Browse the repository at this point in the history
This PR adds below shape related fusions, which is helpful for some
transformer models:
- ShapeInputMerge is to merge all Shape nodes' input NodeArg to a single
one (the 1st one on topo order) if they have the same shape value. This
helps CSE fusion to merge more nodes.
- CSE fusion to support scalar tensor as attribute value. This is mainly
to support ConstantOfShape node.
  • Loading branch information
centwang authored Mar 12, 2024
1 parent 978c40d commit 0c078df
Show file tree
Hide file tree
Showing 7 changed files with 291 additions and 9 deletions.
55 changes: 55 additions & 0 deletions onnxruntime/core/optimizer/common_subexpression_elimination.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "common_subexpression_elimination.h"
#include "core/optimizer/utils.h"
#include "core/graph/graph_utils.h"
#include "core/framework/tensorprotoutils.h"

#include <memory>
#include <type_traits>
Expand Down Expand Up @@ -170,6 +171,32 @@ bool AreRangesEqual(const Range& lhs, const Range& rhs) {
std::equal(lhs.begin(), lhs.end(), rhs.begin());
}

// Check if two tensor attributes are equal scalar tensors, mainly to support ConstantOfShape Op.
// Currently support float, float16 and int64 data types, and requires the data are raw data in TensorProto.
bool AreScalarTensorAttributeEqual(const ONNX_NAMESPACE::TensorProto& lhs_t, const ONNX_NAMESPACE::TensorProto& rhs_t) {
if (!(utils::HasDataType(lhs_t) && utils::HasDataType(rhs_t) && lhs_t.data_type() == rhs_t.data_type() &&
(lhs_t.data_type() == onnx::TensorProto_DataType_FLOAT ||
lhs_t.data_type() == onnx::TensorProto_DataType_FLOAT16 ||
lhs_t.data_type() == onnx::TensorProto_DataType_INT64) &&
lhs_t.dims_size() == 1 && rhs_t.dims_size() == 1 && lhs_t.dims()[0] == 1 && rhs_t.dims()[0] == 1 &&
utils::HasRawData(lhs_t) && utils::HasRawData(rhs_t))) {
return false;
}
const void* lhs_value = lhs_t.raw_data().data();
const void* rhs_value = rhs_t.raw_data().data();
switch (lhs_t.data_type()) {
case onnx::TensorProto_DataType_FLOAT:
return *reinterpret_cast<const float*>(lhs_value) == *reinterpret_cast<const float*>(rhs_value);
case onnx::TensorProto_DataType_FLOAT16:
return *reinterpret_cast<const MLFloat16*>(lhs_value) == *reinterpret_cast<const MLFloat16*>(rhs_value);
case onnx::TensorProto_DataType_INT64:
return *reinterpret_cast<const int64_t*>(lhs_value) == *reinterpret_cast<const int64_t*>(rhs_value);
default:
break;
}
return false;
}

bool AreEqual(const ONNX_NAMESPACE::AttributeProto& lhs, const ONNX_NAMESPACE::AttributeProto& rhs) {
if (&lhs == &rhs) {
return true;
Expand All @@ -193,6 +220,7 @@ bool AreEqual(const ONNX_NAMESPACE::AttributeProto& lhs, const ONNX_NAMESPACE::A
case onnx::AttributeProto_AttributeType_STRINGS:
return AreRangesEqual(lhs.strings(), rhs.strings());
case onnx::AttributeProto_AttributeType_TENSOR:
return AreScalarTensorAttributeEqual(lhs.t(), rhs.t());
case onnx::AttributeProto_AttributeType_GRAPH:
case onnx::AttributeProto_AttributeType_SPARSE_TENSOR:
case onnx::AttributeProto_AttributeType_TYPE_PROTO:
Expand All @@ -207,6 +235,31 @@ bool AreEqual(const ONNX_NAMESPACE::AttributeProto& lhs, const ONNX_NAMESPACE::A
return false;
}

// Support scalar float/int64/fp16 tensor attribute only for now, and requires data is raw data in TensorProto.
std::size_t GetTensorAttributeHash(const ONNX_NAMESPACE::TensorProto& attr_t) {
std::size_t hash = 0;
if (utils::HasDataType(attr_t) && attr_t.dims_size() == 1 && attr_t.dims()[0] == 1 && utils::HasRawData(attr_t)) {
int data_type = attr_t.data_type();
switch (data_type) {
case onnx::TensorProto_DataType_FLOAT:
UpdateHash(data_type, hash);
UpdateHash(*reinterpret_cast<const float*>(attr_t.raw_data().data()), hash);
break;
case onnx::TensorProto_DataType_FLOAT16:
UpdateHash(data_type, hash);
UpdateHash(static_cast<float>(*reinterpret_cast<const MLFloat16*>(attr_t.raw_data().data())), hash);
break;
case onnx::TensorProto_DataType_INT64:
UpdateHash(data_type, hash);
UpdateHash(*reinterpret_cast<const int64_t*>(attr_t.raw_data().data()), hash);
break;
default:
break;
}
}
return hash;
}

std::size_t GetAttributeHash(const ONNX_NAMESPACE::AttributeProto& attr) {
std::size_t hash = 0;
UpdateHash(
Expand All @@ -233,6 +286,8 @@ std::size_t GetAttributeHash(const ONNX_NAMESPACE::AttributeProto& attr) {
UpdateHashWithContainer(attr.strings(), hash);
break;
case onnx::AttributeProto_AttributeType_TENSOR:
UpdateHash(attr.t(), &GetTensorAttributeHash, hash);
break;
case onnx::AttributeProto_AttributeType_GRAPH:
case onnx::AttributeProto_AttributeType_SPARSE_TENSOR:
case onnx::AttributeProto_AttributeType_TYPE_PROTO:
Expand Down
9 changes: 5 additions & 4 deletions onnxruntime/core/optimizer/graph_transformer_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
#include "core/optimizer/reshape_fusion.h"
#include "core/optimizer/rocm_blas_alt_impl.h"
#include "core/optimizer/rule_based_graph_transformer.h"
#include "core/optimizer/shape_input_merge.h"
#include "core/optimizer/skip_layer_norm_fusion.h"
#include "core/optimizer/slice_elimination.h"
#include "core/optimizer/transpose_optimizer.h"
Expand Down Expand Up @@ -211,17 +212,17 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
transformers.emplace_back(std::make_unique<DoubleQDQPairsRemover>());
}

// Put ConstantSharing before CommonSubexpressionElimination by intention as it can create more opportunities for
// CSE. For example, if A and B nodes both do Add operation with a same value but different initializers, by
// default, CSE will not merge them, because the different initializers are represented by different NodeArg.
// Put ConstantSharing and ShapeInputMerge before CommonSubexpressionElimination by intention as it can create
// more opportunities for CSE. For example, if A and B nodes consume same different args but produce same output
// or consume different initializers with same value, by default, CSE will not merge them.
InlinedHashSet<std::string> excluded_initializers;
excluded_initializers.reserve(session_options.initializers_to_share_map.size());
for (const auto& p : session_options.initializers_to_share_map) {
excluded_initializers.insert(p.first);
}
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
transformers.emplace_back(std::make_unique<ConstantSharing>(no_limit_empty_ep_list, excluded_initializers));

transformers.emplace_back(std::make_unique<ShapeInputMerge>());
transformers.emplace_back(std::make_unique<CommonSubexpressionElimination>());
transformers.emplace_back(std::make_unique<ConstantFolding>(cpu_execution_provider, !disable_quant_qdq,
session_options.config_options));
Expand Down
78 changes: 78 additions & 0 deletions onnxruntime/core/optimizer/shape_input_merge.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/optimizer/shape_input_merge.h"

#include "core/graph/graph_utils.h"

namespace onnxruntime {

namespace {
std::string GetShapeString(const NodeArg* input_arg) {
auto shape = input_arg->Shape();
if (!shape) return "";
std::stringstream ss;
ss << "[";
for (int i = 0; i < shape->dim_size(); ++i) {
if (i != 0) ss << ",";
auto dim = shape->dim(i);
if (dim.has_dim_value()) {
ss << std::to_string(dim.dim_value());
} else if (dim.has_dim_param()) {
ss << "'" << dim.dim_param() << "'";
} else {
return "";
}
}
ss << "]";
return ss.str();
}

} // namespace

Status ShapeInputMerge::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
InlinedHashMap<std::string, InlinedVector<Node*>> input_hash_to_nodes;
for (auto node_index : node_topology_list) {
auto* p_node = graph.GetNode(node_index);
if (!p_node) continue; // we removed the node as part of an earlier fusion
ORT_RETURN_IF_ERROR(Recurse(*p_node, modified, graph_level, logger));
if (!graph_utils::IsSupportedOptypeVersionAndDomain(*p_node, "Shape", {1, 13, 15, 19, 21}) ||
!graph_utils::IsSupportedProvider(*p_node, GetCompatibleExecutionProviders())) {
continue;
}
std::string shape_str = GetShapeString(p_node->InputDefs()[0]);
if (shape_str.empty()) continue;
if (input_hash_to_nodes.find(shape_str) == input_hash_to_nodes.end()) {
input_hash_to_nodes[shape_str] = InlinedVector<Node*>();
}
input_hash_to_nodes[shape_str].emplace_back(p_node);
}

// All Shape nodes are processed in topological order, so we can safely merge the inputs to the first node's input.
for (auto& kv : input_hash_to_nodes) {
if (kv.second.size() < 2) continue;
NodeArg* first_input_arg = kv.second[0]->MutableInputDefs()[0];
bool is_first_input_arg_graph_input = graph.IsInputsIncludingInitializers(first_input_arg);
for (size_t i = 1; i < kv.second.size(); ++i) {
Node* p_node = kv.second[i];
const NodeArg* input_arg = p_node->InputDefs()[0];
if (p_node->InputDefs()[0]->Name() == first_input_arg->Name()) continue;
if (!graph.IsInputsIncludingInitializers(input_arg)) {
const Node::EdgeEnd& input_edge = *p_node->InputEdgesBegin();
graph.RemoveEdge(input_edge.GetNode().Index(), p_node->Index(), input_edge.GetSrcArgIndex(), 0);
}
graph_utils::ReplaceNodeInput(*p_node, 0, *first_input_arg);
if (!is_first_input_arg_graph_input) {
const Node::EdgeEnd& first_input_edge = *kv.second[0]->InputEdgesBegin();
graph.AddEdge(first_input_edge.GetNode().Index(), p_node->Index(), first_input_edge.GetSrcArgIndex(), 0);
}
modified = true;
}
}

return Status::OK();
}

} // namespace onnxruntime
23 changes: 23 additions & 0 deletions onnxruntime/core/optimizer/shape_input_merge.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/optimizer/graph_transformer.h"

namespace onnxruntime {

/**
@Class ShapeInputMerge
Merge all shape inputs having same shape value to a single shape input.
This change will not affect the performance, but it open chances for CSE fusion to merge nodes.
*/
class ShapeInputMerge : public GraphTransformer {
public:
ShapeInputMerge(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
: GraphTransformer("ShapeInputMerge", compatible_execution_providers) {}

Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
};

} // namespace onnxruntime
5 changes: 3 additions & 2 deletions onnxruntime/core/optimizer/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -272,15 +272,16 @@ int32_t IndexOfNodeOutput(const Node& node, const NodeArg& node_arg) {
// We could also allow other known domains (kMSDomain, kMSNchwcDomain, kMSFeaturizersDomain),
// as long as we verify which of their operations are non-deterministic and add them in the map below.
constexpr std::array kOnnxDomainNonDeterministicOps{"RandomUniform", "RandomNormal", "RandomUniformLike",
"RandomNormalLike", "Multinomial"};
"RandomNormalLike", "Multinomial", "Dropout"};

// List of deterministic MS domain operators. Currently used for constant folding and common subexpression elimination.
//
// TODO(adrianlizarraga): Investigate converting to lists of *non-deterministic* MS domain operators to be consistent
// with the above ONNX list. With the current approach, only MS domain Q/DQ operators
// (plus ShrunkenGather for training) are considered deterministic.
#ifdef ENABLE_TRAINING_OPS
constexpr std::array kMSDomainDeterministicOps{"ShrunkenGather", "QuantizeLinear", "DequantizeLinear"};
constexpr std::array kMSDomainDeterministicOps{"ShrunkenGather", "QuantizeLinear", "DequantizeLinear",
"ConcatTraining"};
#else
constexpr std::array kMSDomainDeterministicOps{"QuantizeLinear", "DequantizeLinear"};
#endif
Expand Down
122 changes: 122 additions & 0 deletions onnxruntime/test/optimizer/graph_transform_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
#include "core/optimizer/relu_clip_fusion.h"
#include "core/optimizer/reshape_fusion.h"
#include "core/optimizer/rule_based_graph_transformer.h"
#include "core/optimizer/shape_input_merge.h"
#include "core/optimizer/slice_elimination.h"
#include "core/optimizer/unsqueeze_elimination.h"
#include "core/optimizer/utils.h"
Expand Down Expand Up @@ -4879,6 +4880,53 @@ TEST_F(GraphTransformationTests, FastGeluFusionWithCastsTest3) {
ASSERT_TRUE(op_to_count["com.microsoft.FastGelu"] == 1);
}

TEST_F(GraphTransformationTests, CseWithConstantOfShape) {
auto build_test_case = [&](ModelTestBuilder& builder) {
std::vector<std::variant<int64_t, std::string>> input_shape;
input_shape.reserve(4);
input_shape.emplace_back("dim0");
input_shape.emplace_back(512);
input_shape.emplace_back(16);
input_shape.emplace_back("dim3");
auto* input_arg = builder.MakeSymbolicInput<float>(input_shape);
auto* shape_out_1 = builder.MakeIntermediate();
auto* shape_out_2 = builder.MakeIntermediate();
auto* constant_of_shape_out_1 = builder.MakeIntermediate();
auto* constant_of_shape_out_2 = builder.MakeIntermediate();
auto* mul_out_1 = builder.MakeIntermediate();
auto* mul_out_2 = builder.MakeOutput();
builder.AddNode("Shape", {input_arg}, {shape_out_1});
builder.AddNode("Shape", {input_arg}, {shape_out_2});
TensorProto value_tensor;
value_tensor.add_dims(1);
float value = 2.333f;
value_tensor.set_raw_data(reinterpret_cast<const char*>(&value), sizeof(float));
value_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
builder.AddNode("ConstantOfShape", {shape_out_1}, {constant_of_shape_out_1}).AddAttribute("value", value_tensor);
builder.AddNode("ConstantOfShape", {shape_out_2}, {constant_of_shape_out_2}).AddAttribute("value", value_tensor);
builder.AddNode("Mul", {input_arg, constant_of_shape_out_1}, {mul_out_1});
builder.AddNode("Mul", {mul_out_1, constant_of_shape_out_2}, {mul_out_2});
};

auto pre_graph_checker = [&](Graph& graph) {
auto op_count_map = CountOpsInGraph(graph);
TEST_RETURN_IF_NOT(op_count_map["Shape"] == 2);
TEST_RETURN_IF_NOT(op_count_map["ConstantOfShape"] == 2);
return Status::OK();
};

auto post_graph_checker = [&](Graph& graph) {
auto op_count_map = CountOpsInGraph(graph);
TEST_RETURN_IF_NOT(op_count_map["Shape"] == 1);
TEST_RETURN_IF_NOT(op_count_map["ConstantOfShape"] == 1);
return Status::OK();
};

std::unique_ptr<GraphTransformer> transformer = std::make_unique<CommonSubexpressionElimination>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1,
1, pre_graph_checker, post_graph_checker));
}

TEST_F(GraphTransformationTests, QuickGelu) {
// Sigmoid(x*alpha)*x, float
{
Expand Down Expand Up @@ -7543,5 +7591,79 @@ TEST_F(GraphTransformationTests, GatherToSliceFusion) {
}
}

TEST_F(GraphTransformationTests, ShapeInputMerge) {
auto build_test_case = [&](ModelTestBuilder& builder) {
std::vector<std::variant<int64_t, std::string>> input_shape;
input_shape.reserve(5);
input_shape.emplace_back("dim0");
input_shape.emplace_back(512);
input_shape.emplace_back(1);
input_shape.emplace_back(1536);
input_shape.emplace_back("dim4");
auto* input_arg = builder.MakeSymbolicInput<float>(input_shape);
auto* neg_out = builder.MakeIntermediate();
auto* axes_initializer = builder.MakeInitializer<int64_t>({1}, {static_cast<int64_t>(2)});
auto* squeeze_out = builder.MakeIntermediate();
auto* cast_out = builder.MakeIntermediate();
auto* unsqueeze_out = builder.MakeOutput();
auto* shape_1_out = builder.MakeOutput();
auto* shape_2_out = builder.MakeOutput();
auto* shape_3_out = builder.MakeOutput();
auto* shape_4_out = builder.MakeOutput();
auto* shape_5_out = builder.MakeOutput();
builder.AddNode("Neg", {input_arg}, {neg_out});
builder.AddNode("Squeeze", {neg_out, axes_initializer}, {squeeze_out});
builder.AddNode("Cast", {squeeze_out}, {cast_out}).AddAttribute("to", static_cast<int64_t>(10));
builder.AddNode("Unsqueeze", {cast_out, axes_initializer}, {unsqueeze_out});
builder.AddNode("Shape", {input_arg}, {shape_1_out});
builder.AddNode("Shape", {neg_out}, {shape_2_out});
builder.AddNode("Shape", {squeeze_out}, {shape_3_out});
builder.AddNode("Shape", {cast_out}, {shape_4_out});
builder.AddNode("Shape", {unsqueeze_out}, {shape_5_out});
};

auto pre_graph_checker = [&](Graph& graph) {
InlinedHashMap<std::string, int> ref_count;
for (auto& node : graph.Nodes()) {
if (node.OpType() == "Shape") {
std::string name = node.InputDefs()[0]->Name();
if (ref_count.find(name) == ref_count.end()) {
ref_count[name] = 1;
} else {
ref_count[name]++;
}
}
}
TEST_RETURN_IF_NOT(ref_count.size() == 5);
return Status::OK();
};

auto post_graph_checker = [&](Graph& graph) {
InlinedHashMap<std::string, int> ref_count;
for (auto& node : graph.Nodes()) {
if (node.OpType() == "Shape") {
std::string name = node.InputDefs()[0]->Name();
if (ref_count.find(name) == ref_count.end()) {
ref_count[name] = 1;
} else {
ref_count[name]++;
}
}
}
TEST_RETURN_IF_NOT(ref_count.size() == 2);
int sum = 0, mul = 1;
for (auto& entry : ref_count) {
sum += entry.second;
mul *= entry.second;
}
TEST_RETURN_IF_NOT(sum == 5 && mul == 6);
return Status::OK();
};

std::unique_ptr<GraphTransformer> transformer = std::make_unique<ShapeInputMerge>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1,
1, pre_graph_checker, post_graph_checker));
}

} // namespace test
} // namespace onnxruntime
Loading

0 comments on commit 0c078df

Please sign in to comment.