diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc
index c479b685f9267..e6ffd0d91372b 100644
--- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc
+++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc
@@ -6,6 +6,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -98,14 +99,81 @@ static std::unique_ptr<api::NodeRef> MakeSqueezeOrUnsqueeze(int64_t opset, api::
   return graph.AddNode(op_type, inputs, /*num_outputs*/ 1);
 }
 
-// Use to create a QuantizeLinear or DequantizeLinear node. Does not update output ValueInfo. Adds axis if needed.
-static std::unique_ptr<api::NodeRef> MakeQOrDQ(api::GraphRef& graph, std::string_view domain, std::string_view op_type,
-                                               std::vector<std::string_view> inputs,
-                                               std::optional<int64_t> axis) {
-  std::unique_ptr<api::NodeRef> node = graph.AddNode(op_type, inputs, /* num_outputs */ 1, domain);
-  // only set if provided and not the default
-  if (axis && axis != 1) {
-    node->SetAttributeInt("axis", *axis);
+/// <summary>
+/// Sets an attribute on a node if the attribute value is valid and differs from the default value.
+/// </summary>
+/// <param name="node">Node on which to set the attribute</param>
+/// <param name="attr_name">Attribute's name</param>
+/// <param name="attr_val">Attribute value to set</param>
+/// <param name="attr_default_val">Default attribute value</param>
+static void SetAttrIfNotDefault(api::NodeRef& node, std::string_view attr_name,
+                                std::optional<int64_t> attr_val, int64_t attr_default_val) {
+  if (attr_val && attr_val != attr_default_val) {
+    node.SetAttributeInt(attr_name, *attr_val);
+  }
+}
+
+/// <summary>
+/// Adds a new QuantizeLinear node to the graph. Does not update the output's ValueInfo data.
+/// </summary>
+/// <param name="graph">Graph into which to add the new node</param>
+/// <param name="domain">Domain for the new node</param>
+/// <param name="inputs">List of input names for the new node</param>
+/// <param name="axis">Optional 'axis' attribute value</param>
+/// <param name="block_size">Optional 'block_size' attribute value</param>
+/// <param name="output_dtype">Optional 'output_dtype' attribute value</param>
+/// <param name="saturate">Optional 'saturate' attribute value</param>
+/// <returns>Reference to the new QuantizeLinear node</returns>
+static std::unique_ptr<api::NodeRef> MakeQuantizeOp(api::GraphRef& graph, std::string_view domain,
+                                                    std::vector<std::string_view> inputs,
+                                                    std::optional<int64_t> axis,
+                                                    std::optional<int64_t> block_size,
+                                                    std::optional<int64_t> output_dtype,
+                                                    std::optional<int64_t> saturate) {
+  std::unique_ptr<api::NodeRef> node = graph.AddNode("QuantizeLinear", inputs, /* num_outputs */ 1, domain);
+
+  SetAttrIfNotDefault(*node, "axis", axis, 1);
+
+  if (auto opset = graph.Opset(domain); opset) {
+    const int64_t required_opset_1 = IsOnnxDomain(domain) ? 19 : 1;
+    const int64_t required_opset_2 = IsOnnxDomain(domain) ? 21 : 1;
+
+    if (*opset >= required_opset_1) {
+      SetAttrIfNotDefault(*node, "saturate", saturate, 1);
+    }
+
+    if (*opset >= required_opset_2) {
+      SetAttrIfNotDefault(*node, "block_size", block_size, 0);
+      SetAttrIfNotDefault(*node, "output_dtype", output_dtype, 0);
+    }
+  }
+
+  return node;
+}
+
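Note on the opset gating in MakeQuantizeOp: in the ONNX domain, QuantizeLinear gained 'saturate' in opset 19 and 'block_size'/'output_dtype' in opset 21, while the contrib (com.microsoft) Q/DQ ops accept them from opset 1, so an attribute the target opset does not define must simply be dropped. A minimal standalone restatement of that rule (illustrative C++ only; AllowedQuantizeAttrs is a hypothetical helper, not part of ORT):

#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Which optional QuantizeLinear attributes may be emitted for a given domain/opset.
std::vector<std::string> AllowedQuantizeAttrs(bool is_onnx_domain, int64_t opset) {
  std::vector<std::string> attrs{"axis"};  // valid at all supported opsets
  if (!is_onnx_domain || opset >= 19) attrs.push_back("saturate");  // ONNX opset 19+
  if (!is_onnx_domain || opset >= 21) {                             // ONNX opset 21+
    attrs.push_back("block_size");
    attrs.push_back("output_dtype");
  }
  return attrs;
}

int main() {
  assert(AllowedQuantizeAttrs(true, 13).size() == 1);  // only 'axis'
  assert(AllowedQuantizeAttrs(true, 19).size() == 2);  // + 'saturate'
  assert(AllowedQuantizeAttrs(true, 21).size() == 4);  // + 'block_size', 'output_dtype'
  return 0;
}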
+/// <summary>
+/// Adds a new DequantizeLinear node to the graph. Does not update the output's ValueInfo data.
+/// </summary>
+/// <param name="graph">Graph into which to add the new node</param>
+/// <param name="domain">Domain for the new node</param>
+/// <param name="inputs">List of input names for the new node</param>
+/// <param name="axis">Optional 'axis' attribute value</param>
+/// <param name="block_size">Optional 'block_size' attribute value</param>
+/// <returns>Reference to the new DequantizeLinear node</returns>
+static std::unique_ptr<api::NodeRef> MakeDequantizeOp(api::GraphRef& graph, std::string_view domain,
+                                                      std::vector<std::string_view> inputs,
+                                                      std::optional<int64_t> axis,
+                                                      std::optional<int64_t> block_size) {
+  std::unique_ptr<api::NodeRef> node = graph.AddNode("DequantizeLinear", inputs, /* num_outputs */ 1, domain);
+
+  SetAttrIfNotDefault(*node, "axis", axis, 1);
+
+  if (auto opset = graph.Opset(domain); opset) {
+    const int64_t required_opset = IsOnnxDomain(domain) ? 21 : 1;
+
+    if (*opset >= required_opset) {
+      SetAttrIfNotDefault(*node, "block_size", block_size, 0);
+    }
   }
 
   return node;
 }
@@ -263,6 +331,7 @@ static bool MakeQDQNodeUnit(api::GraphRef& graph, const api::NodeRef& dq_node) {
   auto update_dq_axis = scale_shape && !scale_shape->empty();
   int64_t axis = dq_node.GetAttributeIntDefault("axis", 1);
 
+  // TODO(adrianlizarraga): Also need to update axis if Unsqueeze inserts a 1 before the axis dim.
   if (update_dq_axis && is_transpose) {
     // update axis.
     auto perm = GetPermAttrIfValid(next_node);
@@ -281,7 +350,8 @@ static bool MakeQDQNodeUnit(api::GraphRef& graph, const api::NodeRef& dq_node) {
   }
 
   // Add Q
-  auto new_q_node = MakeQOrDQ(graph, dq_domain, "QuantizeLinear", inputs, axis);
+  auto new_q_node = MakeQuantizeOp(graph, dq_domain, inputs, axis, dq_node.GetAttributeInt("block_size"),
+                                   dq_node.GetAttributeInt("output_dtype"), dq_node.GetAttributeInt("saturate"));
   auto q_node_outputs = new_q_node->Outputs();
 
   // copy value info from the dq input for the type information, and update the shape to match next_node's output
@@ -293,7 +363,7 @@ static bool MakeQDQNodeUnit(api::GraphRef& graph, const api::NodeRef& dq_node) {
   inputs[0] = new_q_node->Outputs()[0];
 
   // Add DQ
-  auto new_dq_node = MakeQOrDQ(graph, dq_domain, "DequantizeLinear", inputs, axis);
+  auto new_dq_node = MakeDequantizeOp(graph, dq_domain, inputs, axis, dq_node.GetAttributeInt("block_size"));
   auto dq_node_outputs = new_dq_node->Outputs();
 
   // straight copy of value info as the type and shape are the same as next_node's output
@@ -499,6 +569,51 @@ static std::vector<int64_t> UnsqueezeShape(gsl::span<const int64_t> shape, const
   return new_shape;
 }
 
+/// <summary>
+/// Returns a new squeezed shape without the dimensions of value 1 indicated by the given axes.
+/// </summary>
+/// <param name="shape">Input shape to squeeze</param>
+/// <param name="axes">List of integers indicating the dimensions to squeeze</param>
+/// <returns>New squeezed shape</returns>
+static std::vector<int64_t> SqueezeShape(gsl::span<const int64_t> shape, const std::vector<int64_t>& axes) {
+  const size_t init_rank = shape.size();
+  std::vector<int64_t> pos_axes(axes.begin(), axes.end());
+
+  // Normalize negative axis values.
+  for (size_t i = 0; i < pos_axes.size(); i++) {
+    if (pos_axes[i] < 0) {
+      pos_axes[i] += static_cast<int64_t>(init_rank);
+    }
+  }
+
+  // Sort positive axis values and remove duplicates.
+  std::sort(pos_axes.begin(), pos_axes.end());
+  pos_axes.erase(std::unique(pos_axes.begin(), pos_axes.end()), pos_axes.end());
+
+  assert(shape.size() >= pos_axes.size());
+
+  std::vector<int64_t> new_shape;
+  size_t j = 0;
+
+  for (size_t i = 0; i < shape.size(); i++) {
+    if (pos_axes.empty() && shape[i] == 1) {
+      // If axes is empty, skip all shape values equal to 1.
+      continue;
+    }
+
+    if ((j < pos_axes.size()) && (i == gsl::narrow_cast<size_t>(pos_axes[j]))) {
+      // Skip shape dim if it appears in axes. shape[i] must be 1.
+      assert(shape[i] == 1);
+      j++;
+      continue;
+    }
+
+    new_shape.push_back(shape[i]);
+  }
+
+  return new_shape;
+}
+
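A quick sanity check of the SqueezeShape semantics above, restated as a standalone sketch (std-only, no gsl; SqueezeShapeSketch is a hypothetical name used only for illustration):

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

static std::vector<int64_t> SqueezeShapeSketch(const std::vector<int64_t>& shape, std::vector<int64_t> axes) {
  const int64_t rank = static_cast<int64_t>(shape.size());
  for (auto& a : axes) {
    if (a < 0) a += rank;  // normalize negative axes
  }
  std::sort(axes.begin(), axes.end());
  axes.erase(std::unique(axes.begin(), axes.end()), axes.end());

  std::vector<int64_t> out;
  size_t j = 0;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (axes.empty() && shape[i] == 1) continue;  // empty axes: drop every 1-dim
    if (j < axes.size() && static_cast<int64_t>(i) == axes[j]) {
      ++j;  // axis explicitly listed; the dim must be 1
      continue;
    }
    out.push_back(shape[i]);
  }
  return out;
}

int main() {
  assert((SqueezeShapeSketch({1, 3, 1, 2}, {0, 2}) == std::vector<int64_t>{3, 2}));   // explicit axes
  assert((SqueezeShapeSketch({1, 3, 1, 2}, {-4}) == std::vector<int64_t>{3, 1, 2}));  // -4 normalizes to 0
  assert((SqueezeShapeSketch({1, 3, 1, 2}, {}) == std::vector<int64_t>{3, 2}));       // all 1-dims squeezed
  return 0;
}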
 // Computes new perm for unsqueezed version of a tensor. Unsafe if axes/perm are invalid or have negative values.
 // New perm reorders non-1 dimensions in the same way and leaves 1-dims from unsqueeze unchanged.
 // Ex:
@@ -2374,6 +2489,377 @@ std::optional<OptimizerCtx> MakeOptimizerContext(api::GraphRef& graph,
   return ctx;
 }
 
+/// <summary>
+/// Returns true if the transpose optimizer can modify the given node.
+/// </summary>
+/// <param name="ctx">Optimizer context</param>
+/// <param name="node">Node to check</param>
+/// <returns>True if allowed to modify the given node</returns>
+static bool CanModifyNode(const OptimizerCtx& ctx, const api::NodeRef& node) {
+  const auto& node_ep = node.GetExecutionProviderType();
+  bool can_modify = false;
+
+  if (node_ep.empty()) {
+    // Unassigned nodes can always be modified.
+    can_modify = true;
+  } else if (node_ep == ctx.provider_type) {
+    // We can also modify if the EP name in provider_type is not empty and the node is assigned to that EP.
+    can_modify = true;
+  }
+
+  return can_modify;
+}
+
+/// <summary>
+/// Try to remove an empty DQ -> Q pair that results from moving a Transpose downstream or a Transpose being
+/// canceled out.
+/// (DQ -> Q -> consumer node) => consumer node
+/// </summary>
+/// <param name="ctx">Optimizer context</param>
+/// <param name="q_node">QuantizeLinear node</param>
+/// <returns>True if an empty DQ -> Q was removed</returns>
+static bool TryRemoveEmptyDQQ(OptimizerCtx& ctx, api::NodeRef& q_node) {
+  assert(q_node.OpType() == "QuantizeLinear");
+
+  // Require a DQ as the input to the current node.
+  auto input_node = ctx.graph.GetNodeProducingOutput(q_node.Inputs()[0]);
+  if (!input_node || input_node->OpType() != "DequantizeLinear") {
+    return false;
+  }
+
+  auto& dq_node = *input_node;
+  std::unique_ptr<api::NodeRef> single_consumer_node;
+
+  // Remove the empty DQ -> Q before a consumer node if the DQ and Q have matching types, scale and zero-point.
+  if (OutputValueHasSingleConsumerNode(ctx.graph, dq_node, 0, single_consumer_node) &&
+      OutputValueHasSingleConsumerNode(ctx.graph, q_node, 0, single_consumer_node) &&
+      CheckQDQNodePairMatch(ctx.graph, dq_node, q_node)) {
+    // Connect the Q consumer to the DQ input.
+    for (size_t j_idx = 0, j_end = single_consumer_node->Inputs().size(); j_idx < j_end; ++j_idx) {
+      if (single_consumer_node->Inputs()[j_idx] == q_node.Outputs()[0]) {
+        single_consumer_node->SetInput(j_idx, dq_node.Inputs()[0]);
+        // break; in theory the Q might be providing multiple inputs.
+      }
+    }
+
+    // Disconnect the other nodes and remove them.
+    dq_node.SetInput(0, "");
+    q_node.SetInput(0, "");
+    ctx.graph.RemoveNode(dq_node);
+    ctx.graph.RemoveNode(q_node);
+
+    return true;
+  }
+
+  return false;
+}
+
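The removal above is sound because DequantizeLinear followed by QuantizeLinear with a matching type, scale, and zero-point is an identity on the quantized values, which is what CheckQDQNodePairMatch guards. A standalone numeric check (illustrative values; uint8 with round-to-nearest-even as in the ONNX spec):

#include <algorithm>
#include <cassert>
#include <cmath>

int main() {
  const float scale = 0.05f;
  const int zero_point = 128;

  // DQ -> Q with identical parameters maps every uint8 value back to itself,
  // so a matched pair can be deleted outright.
  for (int q = 0; q <= 255; ++q) {
    const float x = (q - zero_point) * scale;                           // DequantizeLinear
    int rq = static_cast<int>(std::nearbyint(x / scale)) + zero_point;  // QuantizeLinear
    rq = std::clamp(rq, 0, 255);
    assert(rq == q);
  }
  return 0;
}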
+/// <summary>
+/// Try to repair a broken QDQ Transpose node unit that is missing the Q at its output.
+/// The Transpose could be blocked on the Op inside the QDQ node unit:
+///   DQ -> Transpose -> Op -> Q =>
+///   DQ -> Transpose -> Q[new] -> DQ[new] -> Op -> Q
+/// Alternatively, the Transpose could be providing a graph output:
+///   DQ -> Transpose -> graph output =>
+///   DQ -> Transpose -> Q[new] -> DQ[new] -> graph output
+/// </summary>
+/// <param name="ctx">Optimizer context</param>
+/// <param name="transpose_node">Transpose node</param>
+/// <returns>True if the QDQ node unit was repaired</returns>
+static bool TryFixTransposeMissingQ(OptimizerCtx& ctx, api::NodeRef& transpose_node) {
+  assert(transpose_node.OpType() == "Transpose");
+
+  // Require a DQ as the input to the current node.
+  auto input_node = ctx.graph.GetNodeProducingOutput(transpose_node.Inputs()[0]);
+  if (!input_node || input_node->OpType() != "DequantizeLinear") {
+    return false;
+  }
+
+  auto& dq_node = *input_node;
+
+  // GetValueConsumers sets `comprehensive` to false for graph outputs and implicit inputs.
+  // We know Transpose doesn't have implicit inputs, so if `nodes` is empty it can only be a graph output.
+  auto transpose_output = transpose_node.Outputs()[0];
+  auto consumers = ctx.graph.GetValueConsumers(transpose_output);
+  if (consumers->nodes.empty()) {
+    // DQ -> Transpose -> graph output
+  } else {
+    if (consumers->nodes.size() > 1) {
+      // Unexpected to have DQ -> Transpose -> multiple consumers.
+      return false;
+    }
+
+    if (consumers->nodes[0]->OpType() == "QuantizeLinear") {
+      // Already in a QDQ node unit.
+      return false;
+    }
+  }
+
+  // Add Q -> DQ after the DQ -> Transpose.
+  return MakeQDQNodeUnit(ctx.graph, dq_node);
+}
+
+/// <summary>
+/// Fixes a Transpose QDQ node unit that is missing the DQ at its input due to the Transpose being blocked on the Q.
+/// Inserts a Q -> DQ pair before the sequence Transpose -> Q by using the scale and zero-point info from the Q node.
+/// Before: prev_node -> Transpose -> Q
+/// After:  prev_node -> Q[new] -> DQ[new] -> Transpose -> Q
+/// </summary>
+/// <param name="ctx">Optimizer context</param>
+/// <param name="transpose_node">Transpose node</param>
+/// <returns>True if the Q -> DQ insertion was successful</returns>
+static bool TryFixTransposeMissingDQ(OptimizerCtx& ctx, api::NodeRef& transpose_node) {
+  assert(transpose_node.OpType() == "Transpose");
+  auto transpose_input_name = transpose_node.Inputs()[0];
+  auto transpose_output_name = transpose_node.Outputs()[0];
+
+  // Require a Q as the single consumer of this transpose node's output.
+  std::unique_ptr<api::NodeRef> maybe_q_node;
+  if (!OutputValueHasSingleConsumerNode(ctx.graph, transpose_node, 0, maybe_q_node) ||
+      maybe_q_node->OpType() != "QuantizeLinear") {
+    return false;
+  }
+
+  // Get the node upstream from the Transpose.
+  auto prev_node = ctx.graph.GetNodeProducingOutput(transpose_input_name);
+  if (prev_node == nullptr) {
+    // Transpose consumes a graph input or constant. Skip.
+    return false;
+  }
+
+  if (prev_node->OpType() == "DequantizeLinear") {
+    // Transpose is already in a QDQ node unit.
+    return false;
+  }
+
+  auto& q_node = *maybe_q_node;
+  const auto q_node_inputs = q_node.Inputs();
+
+  auto transpose_output_consumers = ctx.graph.GetValueConsumers(transpose_output_name);
+  if (!transpose_output_consumers->comprehensive || transpose_output_consumers->nodes.size() != 1) {
+    // The Q node should be the only consumer of the Transpose.
+    return false;
+  }
+
+  auto transpose_input_consumers = ctx.graph.GetValueConsumers(transpose_input_name);
+  if (transpose_input_consumers->nodes.size() != 1) {
+    // The Transpose node should be the only consumer of its own input.
+    return false;
+  }
+
+  const auto q_domain = q_node.Domain();
+  const auto scale_input = q_node_inputs[1];
+  const auto scale_value_info = ctx.graph.GetValueInfo(scale_input);
+  std::optional<std::string_view> zp_input;
+  std::optional<std::unique_ptr<api::ValueInfoRef>> zp_value_info;
+
+  auto scale_shape = scale_value_info->Shape();
+  if (!scale_shape) {
+    // The axis potentially needs updating due to the transpose, but we don't have the required info to do it.
+    return false;
+  }
+
+  if (q_node_inputs.size() > 2) {
+    zp_input = q_node_inputs[2];
+    zp_value_info = ctx.graph.GetValueInfo(zp_input.value());
+  }
+
+  // Per-axis quantization if not a scalar (shape is empty for a scalar).
+  // Note: there could be an axis value, as the ONNX spec says it is ignored for per-tensor quantization,
+  // so we have to check the shape.
+  const bool update_axis = scale_shape && !scale_shape->empty();
+  int64_t axis = q_node.GetAttributeIntDefault("axis", 1);
+
+  if (update_axis) {
+    auto perm = GetPermAttrIfValid(transpose_node);
+    assert(perm.has_value());  // ONNX shape inferencing checks that `perm` is valid.
+    NormalizeAndValidateAxis(axis, scale_shape->size());
+    axis = (*perm)[gsl::narrow_cast<size_t>(axis)];  // Note: do not invert the permutation.
+  }
+
+  auto transpose_input_shape = ctx.graph.GetValueInfo(transpose_input_name)->Shape();
+
+  // Set up the new Q node's inputs.
+  // We don't connect it to the node preceding the Transpose yet, as we will move the output of that node to the
+  // new DQ first.
+  std::vector<std::string_view> inputs = {"", scale_input};
+  if (zp_input) {
+    inputs.push_back(zp_input.value());
+  }
+
+  // Add Q
+  auto new_q_node = MakeQuantizeOp(ctx.graph, q_domain, inputs, axis, q_node.GetAttributeInt("block_size"),
+                                   q_node.GetAttributeInt("output_dtype"), q_node.GetAttributeInt("saturate"));
+  auto new_q_node_output = new_q_node->Outputs()[0];
+
+  // Copy value info from the Q output for the type information, and update the shape to match the Transpose's input.
+  ctx.graph.CopyValueInfo(q_node.Outputs()[0], new_q_node_output);  // Q produces the same type as the q_node output
+  auto new_q_node_value_info = ctx.graph.GetValueInfo(new_q_node_output);
+  new_q_node_value_info->SetShape(transpose_input_shape ? &*transpose_input_shape : nullptr);
+
+  // Update the input to connect the new DQ to the Q we just added. Reuse the scale and zero-point inputs.
+  inputs[0] = new_q_node->Outputs()[0];
+
+  // Add the new DQ.
+  auto new_dq_node = MakeDequantizeOp(ctx.graph, q_domain, inputs, axis, q_node.GetAttributeInt("block_size"));
+  auto new_dq_node_output = new_dq_node->Outputs()[0];
+  ctx.graph.CopyValueInfo(transpose_input_name, new_dq_node_output);
+
+  auto prev_node_outputs = prev_node->Outputs();
+  size_t prev_node_output_idx = 0;
+  for (size_t out_idx = 0; out_idx < prev_node_outputs.size(); ++out_idx) {
+    if (prev_node_outputs[out_idx] == transpose_input_name) {
+      prev_node_output_idx = out_idx;
+      break;
+    }
+  }
+
+  // Move prev_node's output to the new DQ node, and connect prev_node with the new Q node.
+  ctx.graph.MoveOutput(*prev_node, prev_node_output_idx, *new_dq_node, 0);
+  std::string_view new_prev_node_output_name = prev_node->Outputs()[prev_node_output_idx];
+  new_q_node->SetInput(0, new_prev_node_output_name);
+  ctx.graph.CopyValueInfo(new_dq_node_output, new_prev_node_output_name);
+
+  return true;
+}
+
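On the `axis = (*perm)[axis]` remap in TryFixTransposeMissingDQ above: Transpose's output axis a draws its data from input axis perm[a], so when the per-axis Q/DQ pair is re-created on the pre-transpose value, the quantization axis maps through perm directly, with no inversion. A standalone check with illustrative values:

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Input shape {2, 3, 4} transposed with perm {1, 2, 0} gives output shape {3, 4, 2}:
  // output axis a is input axis perm[a].
  const std::vector<int64_t> in_shape{2, 3, 4};
  const std::vector<int64_t> perm{1, 2, 0};

  std::vector<int64_t> out_shape(in_shape.size());
  for (size_t a = 0; a < perm.size(); ++a) {
    out_shape[a] = in_shape[static_cast<size_t>(perm[a])];
  }
  assert((out_shape == std::vector<int64_t>{3, 4, 2}));

  // Per-axis quantization along output axis 0 (length 3) corresponds to input
  // axis perm[0] == 1, which has the same length, so the scale/zp tensors still fit.
  const int64_t axis_after = 0;
  const int64_t axis_before = perm[static_cast<size_t>(axis_after)];
  assert(in_shape[static_cast<size_t>(axis_before)] == out_shape[static_cast<size_t>(axis_after)]);
  return 0;
}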
+/// <summary>
+/// Fixes QDQ node units that may have been left in an invalid state after the core transpose optimization pass.
+/// </summary>
+/// <param name="ctx">Optimizer context</param>
+/// <returns>True if the graph was modified</returns>
+static bool FixQDQNodeUnits(OptimizerCtx& ctx) {
+  bool changed = false;
+
+  auto graph_nodes = ctx.graph.Nodes();
+  for (size_t i = 0; i < graph_nodes.size(); i++) {
+    auto& node = *graph_nodes[i];
+
+    if (!CanModifyNode(ctx, node)) {
+      continue;
+    }
+
+    std::string_view op_type = node.OpType();
+
+    if (op_type == "QuantizeLinear") {
+      if (TryRemoveEmptyDQQ(ctx, node)) {
+        changed = true;
+        continue;
+      }
+    } else if (op_type == "Transpose") {
+      if (TryFixTransposeMissingQ(ctx, node)) {
+        changed = true;
+        continue;
+      }
+
+      if (TryFixTransposeMissingDQ(ctx, node)) {
+        changed = true;
+        continue;
+      }
+    }
+  }
+
+  return changed;
+}
+
+/// <summary>
+/// Try to constant fold Transpose or Squeeze nodes if their input is a constant.
+/// Returns true if the graph was modified (i.e., at least one of the consumers received a constant-folded value).
+/// </summary>
+/// <param name="ctx">Optimization context state</param>
+/// <param name="node">Squeeze or Transpose node to try to constant-fold</param>
+/// <returns>True if the graph was modified. The node may not have been removed in either case.</returns>
+static bool TryConstantFoldNode(OptimizerCtx& ctx, api::NodeRef& node) {
+  std::string_view node_op_type = node.OpType();
+  const bool is_transpose = node_op_type == "Transpose";
+  const bool is_squeeze = node_op_type == "Squeeze";
+
+  if (!is_transpose && !is_squeeze) {
+    return false;
+  }
+
+  std::string_view node_input_name = node.Inputs()[0];
+  auto const_input = ctx.graph.GetLocalConstant(node_input_name);
+  if (const_input == nullptr) {
+    // Doesn't have a constant input. Skip.
+    return false;
+  }
+
+  std::string_view node_output_name = node.Outputs()[0];
+  auto consumers = ctx.graph.GetValueConsumers(node_output_name);
+
+  if (consumers->nodes.empty()) {
+    // No consumers. Skip.
+    return false;
+  }
+
+  std::string_view new_initializer_name;
+
+  // Create a new squeezed or transposed initializer.
+  // Once we create this new initializer, we're committed to modifying the graph.
+  if (is_transpose) {
+    std::optional<std::vector<int64_t>> perm = GetPermAttrIfValid(node);
+    if (perm == std::nullopt) {
+      // Invalid transpose perm attribute. Should not happen. Skip.
+      return false;
+    }
+
+    new_initializer_name = ctx.graph.AddInitializer(const_input->DType(),
+                                                    const_input->Shape(),
+                                                    const_input->Data());
+    ctx.graph.TransposeInitializer(new_initializer_name, *perm);
+  } else {
+    assert(is_squeeze);
+    std::optional<std::vector<int64_t>> squeeze_axes = ReadFromAttrOrInput(ctx, node, "axes", /*inp_index*/ 1,
+                                                                           /*opset*/ 13);
+    if (squeeze_axes == std::nullopt) {
+      // Invalid Squeeze axes value. Should not happen. Skip.
+      return false;
+    }
+
+    auto squeezed_shape = SqueezeShape(const_input->Shape(), *squeeze_axes);
+    new_initializer_name = ctx.graph.AddInitializer(const_input->DType(),
+                                                    const_input->Shape(),
+                                                    const_input->Data());
+    ctx.graph.ReshapeInitializer(new_initializer_name, squeezed_shape);
+  }
+
+  // Iterate through the consumers and replace their input(s) with the new initializer.
+  for (auto& consumer : consumers->nodes) {
+    std::vector<std::string_view> inputs = consumer->Inputs();
+
+    for (size_t input_idx = 0; input_idx < inputs.size(); input_idx++) {
+      if (inputs[input_idx] == node_output_name) {
+        consumer->SetInput(input_idx, new_initializer_name);
+      }
+    }
+  }
+
+  // Remove the original node if its output is unused.
+  if (!ctx.graph.HasValueConsumers(node_output_name)) {
+    ctx.graph.RemoveNode(node);
+  }
+
+  // Remove the old initializer if it is no longer used.
+  // Will not happen if this initializer was unsqueezed/transposed in-place for another consumer.
+  // Will happen if this initializer is a result of a previous constant-folding operation.
+  //
+  // Example: shared_const --+--> Transpose --> Squeeze --> Op0
+  //                         |
+  //                         +--> Op1
+  //
+  // The first call to TryConstantFoldNode(transpose) does not remove shared_const because it is used by 'Op1'.
+  // However, the graph becomes:
+  //   transposed_const --> Squeeze --> Op0
+  //   shared_const --> Op1
+  //
+  // The subsequent call to TryConstantFoldNode(squeeze) removes transposed_const from the graph, and we end up with:
+  //   transposed_squeezed_const --> Op0
+  //   shared_const --> Op1
+  if (!ctx.graph.HasValueConsumers(node_input_name)) {
+    ctx.graph.RemoveInitializer(node_input_name);
+  }
+
+  return true;
+}
+
 // Performs optimization. General algorithm: iterate over nodes in topological order. If a node has a transpose
 // as input, push it through if the transpose cost does not increase and is likely to decrease.
 OptimizeResult OptimizeImpl(OptimizerCtx& ctx) {
@@ -2444,20 +2930,6 @@ OptimizeResult OptimizeImpl(OptimizerCtx& ctx) {
   // Existing nodes assigned to the CPU EP can be modified.
   // New nodes can be created and are directly assigned to the CPU EP by setting onnxruntime::ApiGraph::new_node_ep_
   //
-  const auto can_modify_node = [&ctx](const api::NodeRef& node) {
-    const auto& node_ep = node.GetExecutionProviderType();
-    bool can_modify = false;
-
-    if (node_ep.empty()) {
-      // unassigned nodes can always be modified
-      can_modify = true;
-    } else if (node_ep == ctx.provider_type) {
-      // we can also modify if the EP name in provider_type is not empty and the node is assigned to that EP.
-      can_modify = true;
-    }
-
-    return can_modify;
-  };
 
   // Optimize graph. Nodes will be modified during iteration, but nodes are never deleted before we reach them.
   // New transpose nodes are inserted, but always as an input to an existing node.
@@ -2467,7 +2939,7 @@ OptimizeResult OptimizeImpl(OptimizerCtx& ctx) {
       have_dq = true;
     }
 
-    if (!can_modify_node(node)) {
+    if (!CanModifyNode(ctx, node)) {
      continue;
    }
 
@@ -2495,97 +2967,54 @@ OptimizeResult OptimizeImpl(OptimizerCtx& ctx) {
     return result;
   }
 
-  // Run 'fix up' pass for QDQ node units.
+  // Run constant-folding for Transpose and Squeeze ops to fold sequences like (const --> Squeeze --> DQ)
+  // into (squeezed_const --> DQ).
   //
-  // Repair broken QDQ node unit from Transpose being blocked on Op inside a QDQ node unit.
-  // DQ -> Transpose -> Op -> Q =>
-  // DQ -> Transpose -> Q -> DQ -> Op -> Q
+  // These constant-foldable sequences are created when a transpose is pushed through a node that has a shared
+  // initializer as one of its inputs. The node's inputs must be transposed, and for initializer inputs this transpose
+  // is done in-place. Other consumers of this modified shared initializer must get a Transpose node inserted after
+  // the initializer to undo the in-place transformation.
   //
-  // Create QDQ node unit for Transpose after DQ that provides graph output.
-  // DQ -> Transpose -> graph output =>
-  // DQ -> Transpose -> Q -> DQ -> graph output
+  // Example:
+  //   in_place_transposed_const ---+--> DQ0 --> Op0 --> Q0 -->
+  //                                |
+  //                                |
+  //                                +--> Transpose[to undo] --> DQ1 --> Op1 --> Q1 -->
   //
-  // Remove empty DQ -> Q pair from moving a Transpose downstream or a Transpose being cancelled out.
-  // DQ -> Q -> consumer node =>
-  //   consumer node
-
+  // In the above example, constant folding would remove the Transpose and the graph would become:
+  //   in_place_transposed_const --> DQ0 --> Op0 --> Q0 -->
+  //   new_const --> DQ1 --> Op1 --> Q1 -->
+  //
+  // If the shared initializer needs to be broadcast before being transposed in-place, then we'll also end up
+  // with a redundant Squeeze node to undo the broadcast/unsqueeze.
+  //
+  // Example:
+  //   in_place_unsqueezed_transposed_const ---+--> DQ0 --> Op0 --> Q0 -->
+  //                                           |
+  //                                           |
+  //                                           +--> Transpose[to undo] --> Squeeze[to undo] --> DQ1 --> Op1 --> Q1 -->
+  //
+  // In this case, constant folding would remove both the Transpose and Squeeze nodes:
+  //   in_place_unsqueezed_transposed_const --> DQ0 --> Op0 --> Q0 -->
+  //   new_const --> DQ1 --> Op1 --> Q1 -->
   auto graph_nodes = ctx.graph.Nodes();
-  for (size_t i = 1; i < graph_nodes.size(); i++) {
+  for (size_t i = 0; i < graph_nodes.size(); i++) {
     auto& node = *graph_nodes[i];
 
-    if (!can_modify_node(node)) {
+    if (!CanModifyNode(ctx, node)) {
       continue;
     }
 
-    for (size_t i_idx = 0, i_end = node.Inputs().size(); i_idx < i_end; ++i_idx) {
-      // any change requires a DQ as the input to the current node
-      auto input_node = ctx.graph.GetNodeProducingOutput(node.Inputs()[i_idx]);
-      if (!input_node || input_node->OpType() != "DequantizeLinear") {
-        continue;
-      }
-
-      auto& dq_node = *input_node;
-      std::unique_ptr<api::NodeRef> single_consumer_node;
-
-      // remove empty DQ -> Q before a consumer node if the DQ and Q have matching types, scale and zp.
-      if (node.OpType() == "QuantizeLinear") {
-        // we don't need to check scale and zp inputs, and we may remove nodes invalidating `node` if we
-        // continue with the loop of inputs so set i_end to bail
-        i_end = 1;
-
-        auto& q_node = node;
-        if (OutputValueHasSingleConsumerNode(ctx.graph, dq_node, 0, single_consumer_node) &&
-            OutputValueHasSingleConsumerNode(ctx.graph, q_node, 0, single_consumer_node) &&
-            CheckQDQNodePairMatch(ctx.graph, dq_node, q_node)) {
-          // connect Q consumer to DQ input
-          for (size_t j_idx = 0, j_end = single_consumer_node->Inputs().size(); j_idx < j_end; ++j_idx) {
-            if (single_consumer_node->Inputs()[j_idx] == q_node.Outputs()[0]) {
-              single_consumer_node->SetInput(j_idx, dq_node.Inputs()[0]);
-              // break; in theory the Q might be providing multiple inputs.
-            }
-          }
-
-          // disconnect other nodes and remove
-          dq_node.SetInput(0, "");
-          q_node.SetInput(0, "");
-          ctx.graph.RemoveNode(dq_node);
-          ctx.graph.RemoveNode(q_node);
-
-          changed = true;
-          continue;
-        }
-      }
-
-      // DQ -> Transpose => DQ -> Transpose -> Q -> DQ if needed
-      if (node.OpType() == "Transpose") {
-        auto& transpose_node = node;
-
-        // GetValueConsumers sets `comprehensive` to false for graph outputs and implicit inputs.
-        // we know Transpose doesn't have implicit inputs so if nodes are empty it can only be a graph output.
-        auto transpose_output = transpose_node.Outputs()[0];
-        auto consumers = ctx.graph.GetValueConsumers(transpose_output);
-        if (consumers->nodes.empty()) {
-          // DQ -> Transpose -> graph output
-        } else {
-          if (consumers->nodes.size() > 1) {
-            // unexpected to have DQ -> Transpose -> multiple consumers
-            continue;
-          }
-
-          if (consumers->nodes[0]->OpType() == "QuantizeLinear") {
-            // already in QDQ node unit
-            continue;
-          }
-        }
-
-        // Add Q -> DQ after the DQ -> Transpose
-        if (MakeQDQNodeUnit(ctx.graph, dq_node)) {
-          changed = true;
-        }
-      }
+    if (TryConstantFoldNode(ctx, node)) {
+      changed = true;
     }
   }
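The fold performed by TryConstantFoldNode is purely a rearrangement (or reshape) of the initializer's data. A standalone sketch of what transposing a constant conceptually does for a rank-2 initializer (illustrative plain arrays, not the ORT API):

#include <cassert>

int main() {
  // Folding Transpose(perm = {1, 0}) of a constant 2x3 initializer yields a new
  // 3x2 initializer; consumers are then rewired to read the folded data directly.
  const float w[2][3] = {{1, 2, 3}, {4, 5, 6}};
  float folded[3][2] = {};

  for (int i = 0; i < 2; ++i) {
    for (int j = 0; j < 3; ++j) {
      folded[j][i] = w[i][j];  // out[j][i] = in[i][j] for perm {1, 0}
    }
  }

  assert(folded[0][1] == 4.0f && folded[2][0] == 3.0f);
  return 0;
}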
+
+  // Run 'fix up' pass for QDQ node units.
+  if (FixQDQNodeUnits(ctx)) {
+    changed = true;
+  }
+
   result.graph_modified = changed;
   return result;
 }
diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc
index c1cd21570a6a4..16f0752e3f603 100644
--- a/onnxruntime/core/session/inference_session.cc
+++ b/onnxruntime/core/session/inference_session.cc
@@ -1186,8 +1186,11 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool
                                                       std::move(cpu_allocator), debug_graph_fn));
 
   // Previously we ran the L1 transformers to handle constant folding of any initializers that were transposed in
-  // a QDQ format model. The transpose optimizer can now look past DQ nodes to directly update initializers which
-  // takes care of most models without needing this.
+  // a QDQ format model. The transpose optimizer can now do the following, which takes care of most models without
+  // needing this:
+  //   - Look past DQ nodes to directly update initializers in-place.
+  //   - Fix up broken Transpose QDQ groups.
+  //   - Constant fold inserted Squeeze and Transpose ops.
   //
   // if (modified) {
   //   ORT_RETURN_IF_ERROR_SESSIONID_(
diff --git a/onnxruntime/test/optimizer/transpose_optimizer_test.cc b/onnxruntime/test/optimizer/transpose_optimizer_test.cc
index bae37adb19eef..ea2823916798e 100644
--- a/onnxruntime/test/optimizer/transpose_optimizer_test.cc
+++ b/onnxruntime/test/optimizer/transpose_optimizer_test.cc
@@ -4606,8 +4606,213 @@ TEST(TransposeOptimizerTests, QnnTransposeNonConstBroadcastInput) {
     }
   }
 }
+
+// Layout transform's cost function aggressively pushes down transposes with channel-first or channel-last perms.
+// This can lead to a situation where a channel-first/last Transpose gets stuck after being pushed down an Unsqueeze
+// that makes the Transpose's perm no longer channel-first/last. This breaks the QDQ node units for both the
+// Unsqueeze and the Transpose: DQ -> Unsqueeze -> Transpose -> Q.
+// The transpose optimizer should insert a Q -> DQ pair between the Unsqueeze and Transpose nodes to fix both
+// QDQ node units: DQ -> Unsqueeze -> Q[new] -> DQ[new] -> Transpose -> Q
+TEST(TransposeOptimizerTests, LayoutTransformFixStuckTransposeWithoutDQ) {
+  Status status;
+
+  // Using a sub-model extracted from a model that we tried to run with QNN EP.
+  auto model_uri = ORT_TSTR("testdata/layout_transform_fix_transpose_without_dq.qdq.onnx");
+
+  SessionOptions so;
+
+  // ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kDebugLayoutTransformation, "1"));
+
+  using InternalTestingEP = onnxruntime::internal_testing_ep::InternalTestingExecutionProvider;
+
+  // Set the test EP to support all ops in the model so that the layout transform applies to all nodes.
+  const std::unordered_set<std::string> empty_set;
+  auto internal_testing_ep = std::make_unique<InternalTestingEP>(empty_set, empty_set, DataLayout::NHWC);
+  internal_testing_ep->EnableStaticKernels().TakeAllNodes();
+
+  InferenceSessionWrapper session{so, GetEnvironment()};
+  ASSERT_STATUS_OK(session.RegisterExecutionProvider(std::move(internal_testing_ep)));
+  ASSERT_STATUS_OK(session.Load(model_uri));
+  ASSERT_STATUS_OK(session.Initialize());
+
+  const auto& graph = session.GetGraph();
+  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+
+  ASSERT_EQ(op_to_count["Transpose"], 2) << "Should have 2 transposes remaining.";
+
+  std::string expected_ep(onnxruntime::utils::kInternalTestingExecutionProvider);
+  for (const auto& node : graph.Nodes()) {
+    EXPECT_EQ(node.GetExecutionProviderType(), expected_ep) << node.OpType() << " node named '" << node.Name()
+                                                            << "' was not assigned to the internal testing EP.";
+    // All Transpose nodes should be in QDQ node units.
+    if (node.OpType() == "Transpose") {
+      for (auto cur_input = node.InputNodesBegin(), end = node.InputNodesEnd(); cur_input != end; ++cur_input) {
+        EXPECT_EQ(cur_input->OpType(), "DequantizeLinear");
+      }
+
+      for (auto cur_output = node.OutputNodesBegin(), end = node.OutputNodesEnd(); cur_output != end; ++cur_output) {
+        EXPECT_EQ(cur_output->OpType(), "QuantizeLinear");
+      }
+    }
+  }
+}
+
+// Tests the transpose optimizer's ability to constant fold inserted Transpose and Squeeze nodes.
+// After the core transpose optimization loop, the test model contains the following "constant foldable" sequence:
+//
+//   unsqueezed_transposed_weight --+--> Transpose ---> Squeeze ---> DequantizeLinear ---> Mul ---> ...
+//                                  |
+//                                  +--> DequantizeLinear --> Mul --> ...
+//
+// After constant-folding the Transpose and Squeeze nodes, the final model looks like:
+//
+//   new_folded_weight ---> DequantizeLinear ---> Mul ---> ...
+//   unsqueezed_transposed_weight ---> DequantizeLinear ---> Mul ---> ...
+TEST(TransposeOptimizerTests, LayoutTransformConstantFoldTransposeAndSqueeze) {
+  Status status;
+
+  // The test model has a shared initializer that is unsqueezed and transposed in-place for one consumer.
+  // The other consumer gets a Transpose -> Squeeze sequence inserted before its input.
+  // This Transpose -> Squeeze sequence should get constant-folded.
+  auto model_uri = ORT_TSTR("testdata/layout_transform_const_folding.qdq.onnx");
+
+  SessionOptions so;
+
+  // ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kDebugLayoutTransformation, "1"));
+
+  using InternalTestingEP = onnxruntime::internal_testing_ep::InternalTestingExecutionProvider;
+
+  // Set the test EP to support all ops in the model so that the layout transform applies to all nodes.
+  const std::unordered_set<std::string> empty_set;
+  auto internal_testing_ep = std::make_unique<InternalTestingEP>(empty_set, empty_set, DataLayout::NHWC);
+  internal_testing_ep->EnableStaticKernels().TakeAllNodes();
+
+  InferenceSessionWrapper session{so, GetEnvironment()};
+  ASSERT_STATUS_OK(session.RegisterExecutionProvider(std::move(internal_testing_ep)));
+  ASSERT_STATUS_OK(session.Load(model_uri));
+  ASSERT_STATUS_OK(session.Initialize());
+
+  const auto& graph = session.GetGraph();
+  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+
+  // All Squeeze nodes should have been constant-folded by the transpose optimizer.
+  ASSERT_EQ(op_to_count["Squeeze"], 0) << "Should have 0 Squeeze nodes remaining.";
+
+  // 1 transpose is constant-folded, 1 is canceled, and 1 remains.
+  ASSERT_EQ(op_to_count["Transpose"], 1) << "Should have 1 transpose remaining.";
+
+  std::string expected_ep(onnxruntime::utils::kInternalTestingExecutionProvider);
+  for (const auto& node : graph.Nodes()) {
+    EXPECT_EQ(node.GetExecutionProviderType(), expected_ep) << node.OpType() << " node named '" << node.Name()
+                                                            << "' was not assigned to the internal testing EP.";
+    // All Transpose nodes should be in QDQ node units.
+    if (node.OpType() == "Transpose") {
+      for (auto cur_input = node.InputNodesBegin(), end = node.InputNodesEnd(); cur_input != end; ++cur_input) {
+        EXPECT_EQ(cur_input->OpType(), "DequantizeLinear");
+      }
+
+      for (auto cur_output = node.OutputNodesBegin(), end = node.OutputNodesEnd(); cur_output != end; ++cur_output) {
+        EXPECT_EQ(cur_output->OpType(), "QuantizeLinear");
+      }
+    }
+  }
+}
 #endif  // !defined(ORT_MINIMAL_BUILD) && !defined(DISABLE_CONTRIB_OPS)
 
+// Checks that a model that is not processed by the transpose optimizer produces the same
+// results as the same model that undergoes transpose optimization with constant folding
+// of Transpose and Squeeze nodes.
+TEST(TransposeOptimizerTests, ConstantFoldTransposeAndSqueezeOutputCorrectness) {
+  // This test model has a shared initializer that is unsqueezed and transposed in-place for one consumer.
+  // The other consumer gets a Transpose -> Squeeze sequence inserted before its input.
+  // This Transpose -> Squeeze sequence should get constant-folded.
+  auto model_uri = ORT_TSTR("testdata/layout_transform_const_folding.qdq.onnx");
+
+  RandomValueGenerator random{123};
+  std::vector<int64_t> input_dims{1, 3, 3, 3};
+  std::vector<float> input0_data = random.Gaussian<float>(input_dims, 0.0f, 1.0f);
+  std::vector<float> input1_data = random.Gaussian<float>(input_dims, 0.0f, 1.0f);
+
+  OrtValue input0;
+  OrtValue input1;
+  CreateMLValue<float>(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], input_dims, input0_data, &input0);
+  CreateMLValue<float>(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], input_dims, input1_data, &input1);
+
+  NameMLValMap feeds{{"input0", input0}, {"input1", input1}};
+
+  std::vector<std::string> output_names{"output0", "output1"};
+  std::vector<OrtValue> fetches_orig;
+  std::vector<OrtValue> fetches;
+
+  SessionOptions so;
+  ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsDisableQuantQDQ, "1"));
+  so.graph_optimization_level = TransformerLevel::Default;  // off
+
+  // Get results with no modifications to the model.
+  {
+    InferenceSessionWrapper session{so, GetEnvironment()};
+    ASSERT_STATUS_OK(session.Load(model_uri));
+    ASSERT_STATUS_OK(session.Initialize());
+    ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches_orig));
+  }
+
+  {
+    InferenceSessionWrapper session{so, GetEnvironment()};
+    ASSERT_STATUS_OK(session.Load(model_uri));
+
+    // We call the ONNX transpose optimizer directly to use a custom cost check function.
+    Graph& graph = session.GetMutableGraph();
+    CPUAllocator allocator;
+
+    namespace alias_oto = onnx_transpose_optimization;
+    auto api_graph = MakeApiGraph(graph,
+                                  TestCPUExecutionProvider()->CreatePreferredAllocators()[0],
+                                  /*new_node_ep*/ nullptr);
+
+    // Use a custom optimization cost check that aggressively pushes channel-last or channel-first transposes.
+    // This causes an existing transpose to be pushed through an op (Op1) with a shared initializer input. The other
+    // consumer (Op0) of the shared initializer will get a "constant-foldable" sequence between itself and its input.
+    //   shared_const --+--> Transpose --> Squeeze --> Op0
+    //                  |
+    //                  +--> Op1
+    auto custom_cost_fn =
+        [](const alias_oto::api::GraphRef& /* graph */,
+           const alias_oto::api::NodeRef& /* node */,
+           const std::vector<int64_t>& perm,
+           const std::unordered_set<std::string>& /* outputs_leading_to_transpose */) -> alias_oto::CostCheckResult {
+      if (perm == alias_oto::ChannelFirstToLastPerm(perm.size()) ||
+          perm == alias_oto::ChannelLastToFirstPerm(perm.size())) {
+        return alias_oto::CostCheckResult::kPushTranspose;
+      }
+
+      return alias_oto::CostCheckResult::kFallThrough;
+    };
+
+    alias_oto::OptimizeResult result = alias_oto::Optimize(*api_graph, /*provider_type*/ "", custom_cost_fn);
+
+    ASSERT_EQ(result.error_msg, std::nullopt);
+    ASSERT_TRUE(result.graph_modified);
+    ASSERT_TRUE(graph.GraphResolveNeeded());
+    ASSERT_STATUS_OK(graph.Resolve());
+
+    // Use this hack to save the model for viewing if needed:
+    // ASSERT_STATUS_OK(Model::Save(const_cast<Model&>(session.GetModel()), "transpose_opt_updated_const_fold.onnx"));
+
+    std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+    EXPECT_EQ(op_to_count["Squeeze"], 0) << "The Squeeze nodes should have been folded.";
" + << "Only the pre-existing Transpose should remain."; + + ASSERT_STATUS_OK(session.Initialize()); + ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches)); + } + + ASSERT_THAT(fetches_orig[0].Get().DataAsSpan(), + testing::ContainerEq(fetches[0].Get().DataAsSpan())); + ASSERT_THAT(fetches_orig[1].Get().DataAsSpan(), + testing::ContainerEq(fetches[1].Get().DataAsSpan())); +} + static void CheckSharedInitializerHandling(bool broadcast) { auto model_uri = broadcast ? ORT_TSTR("testdata/transpose_optimizer_shared_initializers_broadcast.onnx") : ORT_TSTR("testdata/transpose_optimizer_shared_initializers.onnx"); diff --git a/onnxruntime/test/testdata/layout_transform_const_folding.qdq.onnx b/onnxruntime/test/testdata/layout_transform_const_folding.qdq.onnx new file mode 100644 index 0000000000000..d5fe768b47efa Binary files /dev/null and b/onnxruntime/test/testdata/layout_transform_const_folding.qdq.onnx differ diff --git a/onnxruntime/test/testdata/layout_transform_fix_transpose_without_dq.qdq.onnx b/onnxruntime/test/testdata/layout_transform_fix_transpose_without_dq.qdq.onnx new file mode 100644 index 0000000000000..62c4e60c2c43b Binary files /dev/null and b/onnxruntime/test/testdata/layout_transform_fix_transpose_without_dq.qdq.onnx differ diff --git a/onnxruntime/test/testdata/make_qdq_layout_transform_const_folding.py b/onnxruntime/test/testdata/make_qdq_layout_transform_const_folding.py new file mode 100644 index 0000000000000..34b2ff923b18f --- /dev/null +++ b/onnxruntime/test/testdata/make_qdq_layout_transform_const_folding.py @@ -0,0 +1,109 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import numpy as np +import onnx + +import onnxruntime +from onnxruntime.quantization import CalibrationDataReader, QuantFormat, QuantType, quantize_static +from onnxruntime.quantization.shape_inference import quant_pre_process + + +class DataReader(CalibrationDataReader): + def __init__(self, model_path: str): + self.enum_data = None + + # Use inference session to get input shape. + session = onnxruntime.InferenceSession(model_path, providers=["CPUExecutionProvider"]) + + inputs = session.get_inputs() + + self.data_list = [] + + # Generate 10 random float32 inputs + for _ in range(10): + input_data = {inp.name: np.random.random(inp.shape).astype(np.float32) for inp in inputs} + self.data_list.append(input_data) + + self.datasize = len(self.data_list) + + def get_next(self): + if self.enum_data is None: + self.enum_data = iter(self.data_list) + return next(self.enum_data, None) + + def rewind(self): + self.enum_data = None + + +if __name__ == "__main__": + """ + Creates a QDQ model with a shared initializer. + The transpose optimizer will generate a (weight -> Transpose -> Squeeze) sequence that can be constant folded + by the tranpose optimizer itself. + """ + shape = (1, 3, 3, 3) + + input0 = onnx.helper.make_tensor_value_info("input0", onnx.TensorProto.FLOAT, shape) + input1 = onnx.helper.make_tensor_value_info("input1", onnx.TensorProto.FLOAT, shape) + output0 = onnx.helper.make_tensor_value_info("output0", onnx.TensorProto.FLOAT, None) + output1 = onnx.helper.make_tensor_value_info("output1", onnx.TensorProto.FLOAT, None) + + # Shared weight (will be unsqueezed and fed into a Squeeze op by layout transformation). 
+    const_1_weight = onnx.numpy_helper.from_array(np.array([1.0] * shape[-1], dtype=np.float32), "const_1_weight")
+
+    # Transpose with channel-first perm
+    transpose_node = onnx.helper.make_node(
+        "Transpose", ["input0"], ["transpose_out"], name="transpose_node", perm=(0, 3, 1, 2)
+    )
+
+    # Mul0
+    mul0_node = onnx.helper.make_node("Mul", ["transpose_out", "const_1_weight"], ["mul0_out"], name="mul0_node")
+
+    # Mul1
+    mul1_node = onnx.helper.make_node("Mul", ["input1", "const_1_weight"], ["output1"], name="mul1_node")
+
+    # Conv0
+    conv_w_shape = (1, 3, 2, 2)
+    conv_weight_data = np.random.normal(-1.0, 1.0, conv_w_shape).astype(np.float32)
+    conv_weight = onnx.numpy_helper.from_array(conv_weight_data, "conv_weight")
+    conv_node = onnx.helper.make_node("Conv", ["mul0_out", "conv_weight"], ["output0"], name="conv_node")
+
+    graph = onnx.helper.make_graph(
+        [transpose_node, mul0_node, mul1_node, conv_node],
+        "layout_transform_const_folding",
+        [input0, input1],
+        [output0, output1],
+        initializer=[const_1_weight, conv_weight],
+    )
+    opset_imports = [
+        onnx.helper.make_opsetid("", 19),
+    ]
+    f32_model = onnx.helper.make_model(graph, opset_imports=opset_imports)
+
+    print("[INFO]: Running onnx.checker on f32 model")
+    f32_model = onnx.shape_inference.infer_shapes(f32_model)
+    onnx.checker.check_model(f32_model, True)
+    f32_model_path = "layout_transform_const_folding.f32.onnx"
+
+    print(f"[INFO]: Saving {f32_model_path}")
+    onnx.save_model(f32_model, f32_model_path)
+
+    # Quantize the model.
+    qdq_model_path = "layout_transform_const_folding.qdq.onnx"
+    print("[INFO]: Creating QDQ model")
+    quantize_static(
+        f32_model_path,
+        qdq_model_path,
+        DataReader(f32_model_path),
+        quant_format=QuantFormat.QDQ,
+        activation_type=QuantType.QUInt8,
+        weight_type=QuantType.QUInt8,
+        op_types_to_quantize=[node.op_type for node in f32_model.graph.node],
+        extra_options={"DedicatedQDQPair": True, "ForceQuantizeNoInputCheck": True},
+    )
+    quant_pre_process(qdq_model_path, qdq_model_path)
+    qdq_model = onnx.load_model(qdq_model_path)
+    onnx.checker.check_model(qdq_model, True)
+    onnx.save_model(qdq_model, qdq_model_path)
+    print(f"[INFO]: Created QDQ model {qdq_model_path}")