diff --git a/onnxruntime/core/optimizer/common_subexpression_elimination.cc b/onnxruntime/core/optimizer/common_subexpression_elimination.cc
index b2e7ef0b4f558..48df511d0c672 100644
--- a/onnxruntime/core/optimizer/common_subexpression_elimination.cc
+++ b/onnxruntime/core/optimizer/common_subexpression_elimination.cc
@@ -4,6 +4,8 @@
 #include "common_subexpression_elimination.h"
 #include "core/optimizer/utils.h"
 #include "core/graph/graph_utils.h"
+#include "core/framework/tensorprotoutils.h"
 
+#include <cstring>
 #include
 #include
@@ -170,6 +172,40 @@ bool AreRangesEqual(const Range& lhs, const Range& rhs) {
          std::equal(lhs.begin(), lhs.end(), rhs.begin());
 }
 
+// Byte width of one element for the tensor data types whose scalar attributes CSE can compare
+// (float, float16 and int64). Returns 0 for every other data type.
+std::size_t GetSupportedScalarElementSize(int data_type) {
+  switch (data_type) {
+    case onnx::TensorProto_DataType_FLOAT:
+      return sizeof(float);
+    case onnx::TensorProto_DataType_FLOAT16:
+      return sizeof(uint16_t);
+    case onnx::TensorProto_DataType_INT64:
+      return sizeof(int64_t);
+    default:
+      return 0;
+  }
+}
+
+// Returns true if the tensor is a one-element 1-D tensor of a supported data type whose raw_data holds exactly
+// one element. Checking the byte count keeps later reads of raw_data within bounds for malformed protos.
+bool IsSupportedScalarTensor(const ONNX_NAMESPACE::TensorProto& t) {
+  if (!utils::HasDataType(t) || t.dims_size() != 1 || t.dims()[0] != 1 || !utils::HasRawData(t)) {
+    return false;
+  }
+  const std::size_t element_size = GetSupportedScalarElementSize(t.data_type());
+  return element_size != 0 && t.raw_data().size() == element_size;
+}
+
+// Check if two tensor attributes are equal scalar tensors, mainly to support ConstantOfShape Op.
+// Currently supports float, float16 and int64 data types, and requires the data to be stored as raw data in
+// the TensorProto. The element bytes are compared bitwise, so +0.0 and -0.0 are treated as different values
+// (merging them would change results) and identical NaN bit patterns are treated as equal.
+bool AreScalarTensorAttributeEqual(const ONNX_NAMESPACE::TensorProto& lhs_t, const ONNX_NAMESPACE::TensorProto& rhs_t) {
+  return IsSupportedScalarTensor(lhs_t) && IsSupportedScalarTensor(rhs_t) &&
+         lhs_t.data_type() == rhs_t.data_type() && lhs_t.raw_data() == rhs_t.raw_data();
+}
+
 bool AreEqual(const ONNX_NAMESPACE::AttributeProto& lhs, const ONNX_NAMESPACE::AttributeProto& rhs) {
   if (&lhs == &rhs) {
     return true;
@@ -193,6 +229,7 @@ bool AreEqual(const ONNX_NAMESPACE::AttributeProto& lhs, const ONNX_NAMESPACE::A
     case onnx::AttributeProto_AttributeType_STRINGS:
       return AreRangesEqual(lhs.strings(), rhs.strings());
     case onnx::AttributeProto_AttributeType_TENSOR:
+      return AreScalarTensorAttributeEqual(lhs.t(), rhs.t());
     case onnx::AttributeProto_AttributeType_GRAPH:
     case onnx::AttributeProto_AttributeType_SPARSE_TENSOR:
     case onnx::AttributeProto_AttributeType_TYPE_PROTO:
@@ -207,6 +244,43 @@ bool AreEqual(const ONNX_NAMESPACE::AttributeProto& lhs, const ONNX_NAMESPACE::A
   return false;
 }
 
+// Support scalar float/int64/fp16 tensor attribute only for now, and requires data is raw data in TensorProto.
+// Returns 0 for any tensor that does not meet those requirements.
+std::size_t GetTensorAttributeHash(const ONNX_NAMESPACE::TensorProto& attr_t) {
+  std::size_t hash = 0;
+  if (!IsSupportedScalarTensor(attr_t)) {
+    return hash;
+  }
+  const int data_type = attr_t.data_type();
+  const char* raw_bytes = attr_t.raw_data().data();
+  UpdateHash(data_type, hash);
+  // Copy the bytes out with memcpy: raw_data is a byte string with no alignment guarantee for wider types.
+  switch (data_type) {
+    case onnx::TensorProto_DataType_FLOAT: {
+      float value = 0.0f;
+      std::memcpy(&value, raw_bytes, sizeof(value));
+      UpdateHash(value, hash);
+      break;
+    }
+    case onnx::TensorProto_DataType_FLOAT16: {
+      // Hash the 16-bit pattern itself; tensors that compare equal bitwise always share it.
+      uint16_t bits = 0;
+      std::memcpy(&bits, raw_bytes, sizeof(bits));
+      UpdateHash(static_cast<int>(bits), hash);
+      break;
+    }
+    case onnx::TensorProto_DataType_INT64: {
+      int64_t value = 0;
+      std::memcpy(&value, raw_bytes, sizeof(value));
+      UpdateHash(value, hash);
+      break;
+    }
+    default:
+      break;
+  }
+  return hash;
+}
+
 std::size_t GetAttributeHash(const ONNX_NAMESPACE::AttributeProto& attr) {
   std::size_t hash = 0;
   UpdateHash(
@@ -233,6 +307,8 @@ std::size_t GetAttributeHash(const ONNX_NAMESPACE::AttributeProto& attr) {
       UpdateHashWithContainer(attr.strings(), hash);
       break;
     case onnx::AttributeProto_AttributeType_TENSOR:
+      UpdateHash(attr.t(), &GetTensorAttributeHash, hash);
+      break;
     case onnx::AttributeProto_AttributeType_GRAPH:
     case onnx::AttributeProto_AttributeType_SPARSE_TENSOR:
     case onnx::AttributeProto_AttributeType_TYPE_PROTO:
diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc
index f319e7254568d..63612c47f9c56 100644
--- a/onnxruntime/core/optimizer/graph_transformer_utils.cc
+++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -69,6 +69,7 @@
 #include "core/optimizer/reshape_fusion.h"
 #include "core/optimizer/rocm_blas_alt_impl.h"
 #include "core/optimizer/rule_based_graph_transformer.h"
+#include "core/optimizer/shape_input_merge.h"
 #include "core/optimizer/skip_layer_norm_fusion.h"
 #include "core/optimizer/slice_elimination.h"
 #include "core/optimizer/transpose_optimizer.h"
@@ -211,9 +212,9 @@ InlinedVector> GenerateTransformers(
         transformers.emplace_back(std::make_unique());
       }
 
-      // Put ConstantSharing before CommonSubexpressionElimination by intention as it can create more opportunities for
-      // CSE. For example, if A and B nodes both do Add operation with a same value but different initializers, by
-      // default, CSE will not merge them, because the different initializers are represented by different NodeArg.
+      // Put ConstantSharing and ShapeInputMerge before CommonSubexpressionElimination by intention as they can create
+      // more opportunities for CSE. For example, if A and B nodes consume different args but produce the same output,
+      // or consume different initializers with the same value, by default, CSE will not merge them.
InlinedHashSet excluded_initializers;
       excluded_initializers.reserve(session_options.initializers_to_share_map.size());
       for (const auto& p : session_options.initializers_to_share_map) {
@@ -221,7 +222,7 @@ InlinedVector> GenerateTransformers(
       }
       const InlinedHashSet no_limit_empty_ep_list = {};
       transformers.emplace_back(std::make_unique(no_limit_empty_ep_list, excluded_initializers));
-
+      transformers.emplace_back(std::make_unique());
       transformers.emplace_back(std::make_unique());
       transformers.emplace_back(std::make_unique(cpu_execution_provider, !disable_quant_qdq,
                                                  session_options.config_options));
diff --git a/onnxruntime/core/optimizer/shape_input_merge.cc b/onnxruntime/core/optimizer/shape_input_merge.cc
new file mode 100644
index 0000000000000..9f20520e3e3f4
--- /dev/null
+++ b/onnxruntime/core/optimizer/shape_input_merge.cc
@@ -0,0 +1,89 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/optimizer/shape_input_merge.h"
+
+#include <string>
+
+#include "core/graph/graph_utils.h"
+
+namespace onnxruntime {
+
+namespace {
+// Builds a textual key such as "[4,'batch',512]" for the shape of `input_arg`: concrete dims are written as
+// numbers and symbolic dims as quoted names. Returns an empty string when the shape is missing or any dim has
+// neither a value nor a name, in which case the caller skips the node.
+std::string GetShapeString(const NodeArg* input_arg) {
+  const auto* shape = input_arg->Shape();
+  if (!shape) return "";
+  std::string shape_str = "[";
+  for (int i = 0; i < shape->dim_size(); ++i) {
+    if (i != 0) shape_str += ",";
+    const auto& dim = shape->dim(i);  // bind by reference; copying the proto message per dim is unnecessary
+    if (dim.has_dim_value()) {
+      shape_str += std::to_string(dim.dim_value());
+    } else if (dim.has_dim_param()) {
+      shape_str += "'";
+      shape_str += dim.dim_param();
+      shape_str += "'";
+    } else {
+      return "";
+    }
+  }
+  shape_str += "]";
+  return shape_str;
+}
+
+}  // namespace
+
+// Groups the Shape nodes of `graph` by the shape string of their input and rewires every node in a group to read
+// the input of the group's first node (in topological order), so that CSE can later merge the Shape nodes.
+// Sets `modified` when any node input is replaced.
+Status ShapeInputMerge::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
+  GraphViewer graph_viewer(graph);
+  const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
+  InlinedHashMap<std::string, InlinedVector<Node*>> input_hash_to_nodes;
+  for (auto node_index : node_topology_list) {
+    auto* p_node = graph.GetNode(node_index);
+    if (!p_node) continue;  // we removed the node as part of an earlier fusion
+    ORT_RETURN_IF_ERROR(Recurse(*p_node, modified, graph_level, logger));
+    if (!graph_utils::IsSupportedOptypeVersionAndDomain(*p_node, "Shape", {1, 13, 15, 19, 21}) ||
+        !graph_utils::IsSupportedProvider(*p_node, GetCompatibleExecutionProviders())) {
+      continue;
+    }
+    std::string shape_str = GetShapeString(p_node->InputDefs()[0]);
+    if (shape_str.empty()) continue;
+    // operator[] default-constructs the node list on first use, so no separate find/insert is needed.
+    input_hash_to_nodes[shape_str].emplace_back(p_node);
+  }
+
+  // All Shape nodes are processed in topological order, so we can safely merge the inputs to the first node's input.
+  for (auto& kv : input_hash_to_nodes) {
+    auto& shape_nodes = kv.second;
+    if (shape_nodes.size() < 2) continue;
+    Node& first_node = *shape_nodes[0];
+    NodeArg* first_input_arg = first_node.MutableInputDefs()[0];
+    // Check for the input edge itself instead of classifying the NodeArg: an input that is a graph input, an
+    // initializer or (in a subgraph) an outer-scope value has no producer edge in this graph, and dereferencing
+    // InputEdgesBegin() on a node without input edges is undefined behavior.
+    const bool first_node_has_input_edge = first_node.InputEdgesBegin() != first_node.InputEdgesEnd();
+    for (size_t i = 1; i < shape_nodes.size(); ++i) {
+      Node* p_node = shape_nodes[i];
+      if (p_node->InputDefs()[0]->Name() == first_input_arg->Name()) continue;
+      if (p_node->InputEdgesBegin() != p_node->InputEdgesEnd()) {
+        const Node::EdgeEnd& input_edge = *p_node->InputEdgesBegin();
+        graph.RemoveEdge(input_edge.GetNode().Index(), p_node->Index(), input_edge.GetSrcArgIndex(), 0);
+      }
+      graph_utils::ReplaceNodeInput(*p_node, 0, *first_input_arg);
+      if (first_node_has_input_edge) {
+        const Node::EdgeEnd& first_input_edge = *first_node.InputEdgesBegin();
+        graph.AddEdge(first_input_edge.GetNode().Index(), p_node->Index(), first_input_edge.GetSrcArgIndex(), 0);
+      }
+      modified = true;
+    }
+  }
+
+  return Status::OK();
+}
+
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/optimizer/shape_input_merge.h b/onnxruntime/core/optimizer/shape_input_merge.h
new file mode 100644
index 0000000000000..5cb943998487b
--- /dev/null
+++ b/onnxruntime/core/optimizer/shape_input_merge.h
@@ -0,0 +1,23 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/optimizer/graph_transformer.h"
+
+namespace onnxruntime {
+
+/**
+@Class ShapeInputMerge
+Merges the inputs of all Shape nodes whose inputs have the same shape value into a single shared input.
+This change does not affect performance by itself, but it opens up opportunities for CSE to merge nodes.
+*/
+class ShapeInputMerge : public GraphTransformer {
+ public:
+  ShapeInputMerge(const InlinedHashSet& compatible_execution_providers = {}) noexcept
+      : GraphTransformer("ShapeInputMerge", compatible_execution_providers) {}
+
+  Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
+};
+
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/optimizer/utils.cc b/onnxruntime/core/optimizer/utils.cc
index 7c3599a08ec7a..7055882961e17 100644
--- a/onnxruntime/core/optimizer/utils.cc
+++ b/onnxruntime/core/optimizer/utils.cc
@@ -272,7 +272,7 @@ int32_t IndexOfNodeOutput(const Node& node, const NodeArg& node_arg) {
 // We could also allow other known domains (kMSDomain, kMSNchwcDomain, kMSFeaturizersDomain),
 // as long as we verify which of their operations are non-deterministic and add them in the map below.
 constexpr std::array kOnnxDomainNonDeterministicOps{"RandomUniform", "RandomNormal", "RandomUniformLike",
-                                                    "RandomNormalLike", "Multinomial"};
+                                                    "RandomNormalLike", "Multinomial", "Dropout"};
 
 // List of deterministic MS domain operators. Currently used for constant folding and common subexpression elimination.
 //
@@ -280,7 +280,8 @@ constexpr std::array kOnnxDomainNonDeterministicOps{"RandomUniform", "RandomNorm
 // with the above ONNX list. With the current approach, only MS domain Q/DQ operators
 // (plus ShrunkenGather for training) are considered deterministic.
#ifdef ENABLE_TRAINING_OPS
-constexpr std::array kMSDomainDeterministicOps{"ShrunkenGather", "QuantizeLinear", "DequantizeLinear"};
+constexpr std::array kMSDomainDeterministicOps{"ShrunkenGather", "QuantizeLinear", "DequantizeLinear",
+                                               "ConcatTraining"};
 #else
 constexpr std::array kMSDomainDeterministicOps{"QuantizeLinear", "DequantizeLinear"};
 #endif
diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc
index 1535e2b60a3bd..97f1feaaa612d 100755
--- a/onnxruntime/test/optimizer/graph_transform_test.cc
+++ b/onnxruntime/test/optimizer/graph_transform_test.cc
@@ -65,6 +65,7 @@
 #include "core/optimizer/relu_clip_fusion.h"
 #include "core/optimizer/reshape_fusion.h"
 #include "core/optimizer/rule_based_graph_transformer.h"
+#include "core/optimizer/shape_input_merge.h"
 #include "core/optimizer/slice_elimination.h"
 #include "core/optimizer/unsqueeze_elimination.h"
 #include "core/optimizer/utils.h"
@@ -4879,6 +4880,53 @@ TEST_F(GraphTransformationTests, FastGeluFusionWithCastsTest3) {
   ASSERT_TRUE(op_to_count["com.microsoft.FastGelu"] == 1);
 }
 
+TEST_F(GraphTransformationTests, CseWithConstantOfShape) {  // duplicate Shape->ConstantOfShape chains collapse
+  auto build_test_case = [&](ModelTestBuilder& builder) {
+    std::vector> input_shape;
+    input_shape.reserve(4);
+    input_shape.emplace_back("dim0");
+    input_shape.emplace_back(512);
+    input_shape.emplace_back(16);
+    input_shape.emplace_back("dim3");
+    auto* input_arg = builder.MakeSymbolicInput(input_shape);
+    auto* shape_out_1 = builder.MakeIntermediate();
+    auto* shape_out_2 = builder.MakeIntermediate();
+    auto* constant_of_shape_out_1 = builder.MakeIntermediate();
+    auto* constant_of_shape_out_2 = builder.MakeIntermediate();
+    auto* mul_out_1 = builder.MakeIntermediate();
+    auto* mul_out_2 = builder.MakeOutput();
+    builder.AddNode("Shape", {input_arg}, {shape_out_1});
+    builder.AddNode("Shape", {input_arg}, {shape_out_2});  // same input as the first Shape node
+    TensorProto value_tensor;
+    value_tensor.add_dims(1);  // one-element 1-D tensor stored as raw data below
+    float value = 2.333f;
+    value_tensor.set_raw_data(reinterpret_cast(&value), sizeof(float));
+    value_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
+    builder.AddNode("ConstantOfShape", {shape_out_1}, {constant_of_shape_out_1}).AddAttribute("value", value_tensor);
+    builder.AddNode("ConstantOfShape", {shape_out_2}, {constant_of_shape_out_2}).AddAttribute("value", value_tensor);
+    builder.AddNode("Mul", {input_arg, constant_of_shape_out_1}, {mul_out_1});
+    builder.AddNode("Mul", {mul_out_1, constant_of_shape_out_2}, {mul_out_2});
+  };
+
+  auto pre_graph_checker = [&](Graph& graph) {
+    auto op_count_map = CountOpsInGraph(graph);
+    TEST_RETURN_IF_NOT(op_count_map["Shape"] == 2);
+    TEST_RETURN_IF_NOT(op_count_map["ConstantOfShape"] == 2);
+    return Status::OK();
+  };
+
+  auto post_graph_checker = [&](Graph& graph) {
+    auto op_count_map = CountOpsInGraph(graph);
+    TEST_RETURN_IF_NOT(op_count_map["Shape"] == 1);  // the two Shape nodes were merged
+    TEST_RETURN_IF_NOT(op_count_map["ConstantOfShape"] == 1);  // merged: same input and equal "value" attribute
+    return Status::OK();
+  };
+
+  std::unique_ptr transformer = std::make_unique();
+  ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1,
+                                        1, pre_graph_checker, post_graph_checker));
+}
+
 TEST_F(GraphTransformationTests, QuickGelu) {
   // Sigmoid(x*alpha)*x, float
   {
@@ -7543,5 +7591,79 @@ TEST_F(GraphTransformationTests, GatherToSliceFusion) {
   }
 }
 
+TEST_F(GraphTransformationTests, ShapeInputMerge) {  // Shape nodes over equally-shaped inputs share one input
+  auto build_test_case = [&](ModelTestBuilder& builder) {
+    std::vector> input_shape;
+    input_shape.reserve(5);
+    input_shape.emplace_back("dim0");
+    input_shape.emplace_back(512);
+    input_shape.emplace_back(1);  // axis 2: the dim that Squeeze removes and Unsqueeze restores
+    input_shape.emplace_back(1536);
+    input_shape.emplace_back("dim4");
+    auto* input_arg = builder.MakeSymbolicInput(input_shape);
+    auto* neg_out = builder.MakeIntermediate();
+    auto* axes_initializer = builder.MakeInitializer({1}, {static_cast(2)});  // axes = [2]
+    auto* squeeze_out = builder.MakeIntermediate();
+    auto* cast_out = builder.MakeIntermediate();
+    auto* unsqueeze_out = builder.MakeOutput();
+    auto* shape_1_out = builder.MakeOutput();
+    auto* shape_2_out = builder.MakeOutput();
+    auto* shape_3_out = builder.MakeOutput();
+    auto* shape_4_out = builder.MakeOutput();
+    auto* shape_5_out = builder.MakeOutput();
+    builder.AddNode("Neg", {input_arg}, {neg_out});
+    builder.AddNode("Squeeze", {neg_out, axes_initializer}, {squeeze_out});
+    builder.AddNode("Cast", {squeeze_out}, {cast_out}).AddAttribute("to", static_cast(10));  // 10 == FLOAT16
+    builder.AddNode("Unsqueeze", {cast_out, axes_initializer}, {unsqueeze_out});
+    builder.AddNode("Shape", {input_arg}, {shape_1_out});      // 5-D input shape
+    builder.AddNode("Shape", {neg_out}, {shape_2_out});        // Neg keeps the 5-D shape
+    builder.AddNode("Shape", {squeeze_out}, {shape_3_out});    // 4-D: axis 2 removed
+    builder.AddNode("Shape", {cast_out}, {shape_4_out});       // Cast keeps the 4-D shape
+    builder.AddNode("Shape", {unsqueeze_out}, {shape_5_out});  // 5-D again: axis 2 restored
+  };
+
+  auto pre_graph_checker = [&](Graph& graph) {
+    InlinedHashMap ref_count;  // Shape-node input name -> number of Shape nodes reading it
+    for (auto& node : graph.Nodes()) {
+      if (node.OpType() == "Shape") {
+        std::string name = node.InputDefs()[0]->Name();
+        if (ref_count.find(name) == ref_count.end()) {
+          ref_count[name] = 1;
+        } else {
+          ref_count[name]++;
+        }
+      }
+    }
+    TEST_RETURN_IF_NOT(ref_count.size() == 5);  // every Shape node reads a distinct NodeArg
+    return Status::OK();
+  };
+
+  auto post_graph_checker = [&](Graph& graph) {
+    InlinedHashMap ref_count;  // Shape-node input name -> number of Shape nodes reading it
+    for (auto& node : graph.Nodes()) {
+      if (node.OpType() == "Shape") {
+        std::string name = node.InputDefs()[0]->Name();
+        if (ref_count.find(name) == ref_count.end()) {
+          ref_count[name] = 1;
+        } else {
+          ref_count[name]++;
+        }
+      }
+    }
+    TEST_RETURN_IF_NOT(ref_count.size() == 2);  // one shared input per distinct shape (5-D and 4-D)
+    int sum = 0, mul = 1;
+    for (auto& entry : ref_count) {
+      sum += entry.second;
+      mul *= entry.second;
+    }
+    TEST_RETURN_IF_NOT(sum == 5 && mul == 6);  // group sizes are 2 and 3 (2 + 3 == 5, 2 * 3 == 6)
+    return Status::OK();
+  };
+
+  std::unique_ptr transformer = std::make_unique();
+  ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1,
+                                        1, pre_graph_checker, post_graph_checker));
+}
+
} // namespace test } // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index 5d527369a1b75..9ce88e549eed2 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -44,6 +44,7 @@ #include "core/optimizer/relu_clip_fusion.h" #include "core/optimizer/reshape_fusion.h" #include "core/optimizer/rule_based_graph_transformer.h" +#include "core/optimizer/shape_input_merge.h" #include "core/optimizer/skip_layer_norm_fusion.h" #include "core/optimizer/slice_elimination.h" #include "core/optimizer/unsqueeze_elimination.h" @@ -115,10 +116,11 @@ std::vector> GeneratePreTrainingTransformers( ORT_THROW_IF_ERROR(rule_transformer->Register(std::make_unique())); #endif - // Put ConstantSharing before CommonSubexpressionElimination by intention as it can create more opportunities for - // CSE. For example, if A and B nodes both do Add operation with a same value but different initializers, by - // default, CSE will not merge them, because the different initializers are represented by different NodeArg. + // Put ConstantSharing and ShapeInputMerge before CommonSubexpressionElimination by intention as they can create + // more opportunities for CSE. For example, if A and B nodes consume different args but produce the same output, + // or consume different initializers with the same value, by default, CSE will not merge them. transformers.emplace_back(std::make_unique(compatible_eps)); + transformers.emplace_back(std::make_unique(compatible_eps)); // LayerNormFusion must be applied before CommonSubexpressionElimination as the latter will break the pattern when 2 LayerNormFusion share the same input. transformers.emplace_back(std::make_unique(compatible_eps)); // Remove duplicate nodes. Must be applied before any recompute transformations.