From 8acf60f35c58bfb1358c8316f522062f120a4366 Mon Sep 17 00:00:00 2001
From: Adrian Lizarraga
Date: Mon, 20 May 2024 20:19:06 -0700
Subject: [PATCH] Layout transform: Fix-up QDQ units and add constant folding
 (#20685)

### Description
#### Problem 1: Broken Transpose QDQ unit
Layout transform's specialized cost function aggressively pushes down transposes
with channel-first or channel-last perms. This can lead to a situation where a
channel-first/last Transpose gets stuck after being pushed through an Unsqueeze
node that makes the Transpose's perm no longer channel-first/last. At this point,
the specialized cost function defers to the default cost function, which does not
see a need to continue pushing this transpose node. This breaks the QDQ node units
for both the Unsqueeze and the Transpose: DQ -> Unsqueeze -> Transpose -> Q.

[screenshot omitted]

The transpose optimizer should insert a Q -> DQ pair between the Unsqueeze and
Transpose nodes to fix both QDQ node units:
DQ -> Unsqueeze -> Q[new] -> DQ[new] -> Transpose -> Q

[screenshot omitted]

#### Problem 2: Inserted Squeeze/Transpose nodes should be constant folded when possible.
The transpose optimizer inserts Squeeze (and Transpose) ops between an initializer
and a DQ to counteract the effect of unsqueezing that initializer if it is consumed
by multiple nodes. This results in a graph where the inserted nodes are not in
valid node units.

Original graph where two Mul nodes share a common initializer input:

[screenshot omitted]

Resulting graph after transpose optimization without constant folding:

[screenshot omitted]

Here, the circled Transpose and Squeeze nodes operate on a quantized integer type
but are not in valid QDQ node units. The solution is to run constant folding,
which results in:

[screenshot omitted]

### Motivation and Context
Improve the layout transformation to allow more models to run on EPs that prefer
the channel-last layout.

---------

Co-authored-by: Scott McKay
---
 .../onnx_transpose_optimization.cc            | 637 +++++++++++++++---
 onnxruntime/core/session/inference_session.cc |   7 +-
 .../optimizer/transpose_optimizer_test.cc     | 205 ++++++
 .../layout_transform_const_folding.qdq.onnx   | Bin 0 -> 4845 bytes
 ...ransform_fix_transpose_without_dq.qdq.onnx | Bin 0 -> 24927 bytes
 ...make_qdq_layout_transform_const_folding.py | 109 +++
 6 files changed, 852 insertions(+), 106 deletions(-)
 create mode 100644 onnxruntime/test/testdata/layout_transform_const_folding.qdq.onnx
 create mode 100644 onnxruntime/test/testdata/layout_transform_fix_transpose_without_dq.qdq.onnx
 create mode 100644 onnxruntime/test/testdata/make_qdq_layout_transform_const_folding.py

diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc
index c479b685f9267..e6ffd0d91372b 100644
--- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc
+++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc
@@ -6,6 +6,7 @@
 #include 
 #include 
 #include 
+#include 
 #include 
 #include 
 #include 
@@ -98,14 +99,81 @@ static std::unique_ptr<api::NodeRef> MakeSqueezeOrUnsqueeze(int64_t opset, api::
   return graph.AddNode(op_type, inputs, /*num_outputs*/ 1);
 }
 
-// Use to create a QuantizeLinear or DequantizeLinear node. Does not update output ValueInfo. Adds axis if needed.
-static std::unique_ptr<api::NodeRef> MakeQOrDQ(api::GraphRef& graph, std::string_view domain, std::string_view op_type,
-                                               std::vector<std::string_view> inputs,
-                                               std::optional<int64_t> axis) {
-  std::unique_ptr<api::NodeRef> node = graph.AddNode(op_type, inputs, /* num_outputs */ 1, domain);
-  // only set if provided and not the default
-  if (axis && axis != 1) {
-    node->SetAttributeInt("axis", *axis);
+/// <summary>
+/// Sets an attribute on a node if the attribute value is valid and differs from the default value.
+/// </summary>
+/// <param name="node">Node on which to set the attribute</param>
+/// <param name="attr_name">Attribute's name</param>
+/// <param name="attr_val">Attribute value to set</param>
+/// <param name="attr_default_val">Default attribute value</param>
+static void SetAttrIfNotDefault(api::NodeRef& node, std::string_view attr_name,
+                                std::optional<int64_t> attr_val, int64_t attr_default_val) {
+  if (attr_val && attr_val != attr_default_val) {
+    node.SetAttributeInt(attr_name, *attr_val);
+  }
+}
+
+/// <summary>
+/// Adds a new QuantizeLinear node to the graph. Does not update the output's ValueInfo data.
+/// </summary>
+/// <param name="graph">Graph into which to add the new node</param>
+/// <param name="domain">Domain for the new node</param>
+/// <param name="inputs">List of input names for the new node</param>
+/// <param name="axis">Optional 'axis' attribute value</param>
+/// <param name="block_size">Optional 'block_size' attribute value</param>
+/// <param name="output_dtype">Optional 'output_dtype' attribute value</param>
+/// <param name="saturate">Optional 'saturate' attribute value</param>
+/// <returns>Reference to the new QuantizeLinear node</returns>
+static std::unique_ptr<api::NodeRef> MakeQuantizeOp(api::GraphRef& graph, std::string_view domain,
+                                                    std::vector<std::string_view> inputs,
+                                                    std::optional<int64_t> axis,
+                                                    std::optional<int64_t> block_size,
+                                                    std::optional<int64_t> output_dtype,
+                                                    std::optional<int64_t> saturate) {
+  std::unique_ptr<api::NodeRef> node = graph.AddNode("QuantizeLinear", inputs, /* num_outputs */ 1, domain);
+
+  SetAttrIfNotDefault(*node, "axis", axis, 1);
+
+  if (auto opset = graph.Opset(domain); opset) {
+    const int64_t required_opset_1 = IsOnnxDomain(domain) ? 19 : 1;
+    const int64_t required_opset_2 = IsOnnxDomain(domain) ? 21 : 1;
+
+    if (*opset >= required_opset_1) {
+      SetAttrIfNotDefault(*node, "saturate", saturate, 1);
+    }
+
+    if (*opset >= required_opset_2) {
+      SetAttrIfNotDefault(*node, "block_size", block_size, 0);
+      SetAttrIfNotDefault(*node, "output_dtype", output_dtype, 0);
+    }
+  }
+
+  return node;
+}
+
+/// <summary>
+/// Adds a new DequantizeLinear node to the graph. Does not update the output's ValueInfo data.
+/// </summary>
+/// <param name="graph">Graph into which to add the new node</param>
+/// <param name="domain">Domain for the new node</param>
+/// <param name="inputs">List of input names for the new node</param>
+/// <param name="axis">Optional 'axis' attribute value</param>
+/// <param name="block_size">Optional 'block_size' attribute value</param>
+/// <returns>Reference to the new DequantizeLinear node</returns>
+static std::unique_ptr<api::NodeRef> MakeDequantizeOp(api::GraphRef& graph, std::string_view domain,
+                                                      std::vector<std::string_view> inputs,
+                                                      std::optional<int64_t> axis,
+                                                      std::optional<int64_t> block_size) {
+  std::unique_ptr<api::NodeRef> node = graph.AddNode("DequantizeLinear", inputs, /* num_outputs */ 1, domain);
+
+  SetAttrIfNotDefault(*node, "axis", axis, 1);
+
+  if (auto opset = graph.Opset(domain); opset) {
+    const int64_t required_opset = IsOnnxDomain(domain) ? 21 : 1;
+
+    if (*opset >= required_opset) {
+      SetAttrIfNotDefault(*node, "block_size", block_size, 0);
+    }
   }
 
   return node;
@@ -263,6 +331,7 @@ static bool MakeQDQNodeUnit(api::GraphRef& graph, const api::NodeRef& dq_node) {
 
   auto update_dq_axis = scale_shape && !scale_shape->empty();
   int64_t axis = dq_node.GetAttributeIntDefault("axis", 1);
+  // TODO(adrianlizarraga): Also need to update axis if Unsqueeze inserts a 1 before the axis dim.
   if (update_dq_axis && is_transpose) {
     // update axis.
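+    // For a per-axis DQ, remap the quantization axis through the Transpose's perm so the new Q/DQ pair
+    // still refers to the same logical dimension of the transposed output.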
     auto perm = GetPermAttrIfValid(next_node);
@@ -281,7 +350,8 @@ static bool MakeQDQNodeUnit(api::GraphRef& graph, const api::NodeRef& dq_node) {
   }
 
   // Add Q
-  auto new_q_node = MakeQOrDQ(graph, dq_domain, "QuantizeLinear", inputs, axis);
+  auto new_q_node = MakeQuantizeOp(graph, dq_domain, inputs, axis, dq_node.GetAttributeInt("block_size"),
+                                   dq_node.GetAttributeInt("output_dtype"), dq_node.GetAttributeInt("saturate"));
   auto q_node_outputs = new_q_node->Outputs();
 
   // copy value info from the dq input for the type information, and update the shape to match next_node's output
@@ -293,7 +363,7 @@ static bool MakeQDQNodeUnit(api::GraphRef& graph, const api::NodeRef& dq_node) {
   inputs[0] = new_q_node->Outputs()[0];
 
   // Add DQ
-  auto new_dq_node = MakeQOrDQ(graph, dq_domain, "DequantizeLinear", inputs, axis);
+  auto new_dq_node = MakeDequantizeOp(graph, dq_domain, inputs, axis, dq_node.GetAttributeInt("block_size"));
   auto dq_node_outputs = new_dq_node->Outputs();
 
   // straight copy of value info as the type and shape are the same as next_node's output
@@ -499,6 +569,51 @@ static std::vector<int64_t> UnsqueezeShape(gsl::span<const int64_t> shape, const
   return new_shape;
 }
 
+/// <summary>
+/// Returns a new squeezed shape without the dimensions of value 1 indicated by the given axes.
+/// </summary>
+/// <param name="shape">Input shape to squeeze</param>
+/// <param name="axes">List of integers indicating the dimensions to squeeze</param>
+/// <returns>New squeezed shape</returns>
+static std::vector<int64_t> SqueezeShape(gsl::span<const int64_t> shape, const std::vector<int64_t>& axes) {
+  const size_t init_rank = shape.size();
+  std::vector<int64_t> pos_axes(axes.begin(), axes.end());
+
+  // Normalize negative axis values.
+  for (size_t i = 0; i < pos_axes.size(); i++) {
+    if (pos_axes[i] < 0) {
+      pos_axes[i] += static_cast<int64_t>(init_rank);
+    }
+  }
+
+  // Sort positive axis values and remove duplicates.
+  std::sort(pos_axes.begin(), pos_axes.end());
+  pos_axes.erase(std::unique(pos_axes.begin(), pos_axes.end()), pos_axes.end());
+
+  assert(shape.size() >= pos_axes.size());
+
+  std::vector<int64_t> new_shape;
+  size_t j = 0;
+
+  for (size_t i = 0; i < shape.size(); i++) {
+    if (pos_axes.empty() && shape[i] == 1) {
+      // If axes is empty, skip all shape values equal to 1.
+      continue;
+    }
+
+    if ((j < pos_axes.size()) && (i == gsl::narrow_cast<size_t>(pos_axes[j]))) {
+      // Skip shape dim if it appears in axes. shape[i] must be 1.
+      assert(shape[i] == 1);
+      j++;
+      continue;
+    }
+
+    new_shape.push_back(shape[i]);
+  }
+
+  return new_shape;
+}
+
 // Computes new perm for unsqueezed version of a tensor. Unsafe if axes/perm are invalid or have negative values.
 // New perm reorders non-1 dimensions in the same way and leaves 1-dims from unsqueeze unchanged.
 // Ex:
@@ -2374,6 +2489,377 @@ std::optional<OptimizerCtx> MakeOptimizerContext(api::GraphRef& graph,
   return ctx;
 }
 
+/// <summary>
+/// Returns true if the transpose optimizer can modify the given node.
+/// </summary>
+/// <param name="ctx">Optimizer context</param>
+/// <param name="node">Node to check</param>
+/// <returns>True if allowed to modify the given node</returns>
+static bool CanModifyNode(const OptimizerCtx& ctx, const api::NodeRef& node) {
+  const auto& node_ep = node.GetExecutionProviderType();
+  bool can_modify = false;
+
+  if (node_ep.empty()) {
+    // Unassigned nodes can always be modified
+    can_modify = true;
+  } else if (node_ep == ctx.provider_type) {
+    // We can also modify if the EP name in provider_type is not empty and the node is assigned to that EP.
+    can_modify = true;
+  }
+
+  return can_modify;
+}
+
+/// <summary>
+/// Try to remove empty DQ -> Q pair that results from moving a Transpose downstream or a Transpose being canceled out.
+/// (DQ -> Q -> consumer node) => consumer node
+/// </summary>
+/// <param name="ctx">Optimizer context</param>
+/// <param name="q_node">QuantizeLinear node</param>
+/// <returns>True if an empty DQ -> Q was removed</returns>
+static bool TryRemoveEmptyDQQ(OptimizerCtx& ctx, api::NodeRef& q_node) {
+  assert(q_node.OpType() == "QuantizeLinear");
+
+  // Require a DQ as the input to the current node
+  auto input_node = ctx.graph.GetNodeProducingOutput(q_node.Inputs()[0]);
+  if (!input_node || input_node->OpType() != "DequantizeLinear") {
+    return false;
+  }
+
+  auto& dq_node = *input_node;
+  std::unique_ptr<api::NodeRef> single_consumer_node;
+
+  // remove empty DQ -> Q before a consumer node if the DQ and Q have matching types, scale and zp.
+  if (OutputValueHasSingleConsumerNode(ctx.graph, dq_node, 0, single_consumer_node) &&
+      OutputValueHasSingleConsumerNode(ctx.graph, q_node, 0, single_consumer_node) &&
+      CheckQDQNodePairMatch(ctx.graph, dq_node, q_node)) {
+    // connect Q consumer to DQ input
+    for (size_t j_idx = 0, j_end = single_consumer_node->Inputs().size(); j_idx < j_end; ++j_idx) {
+      if (single_consumer_node->Inputs()[j_idx] == q_node.Outputs()[0]) {
+        single_consumer_node->SetInput(j_idx, dq_node.Inputs()[0]);
+        // break; in theory the Q might be providing multiple inputs.
+      }
+    }
+
+    // disconnect other nodes and remove
+    dq_node.SetInput(0, "");
+    q_node.SetInput(0, "");
+    ctx.graph.RemoveNode(dq_node);
+    ctx.graph.RemoveNode(q_node);
+
+    return true;
+  }
+
+  return false;
+}
+
+/// <summary>
+/// Try to repair a broken QDQ Transpose node unit that is missing the Q at its output.
+/// The Transpose could be blocked on the Op inside the QDQ node unit:
+///   DQ -> Transpose -> Op -> Q =>
+///   DQ -> Transpose -> Q[new] -> DQ[new] -> Op -> Q
+/// Alternatively, the Transpose could be providing a graph output:
+///   DQ -> Transpose -> graph output =>
+///   DQ -> Transpose -> Q[new] -> DQ[new] -> graph output
+/// </summary>
+/// <param name="ctx">Optimizer context</param>
+/// <param name="transpose_node">Transpose node</param>
+/// <returns>True if the QDQ node unit was repaired</returns>
+static bool TryFixTransposeMissingQ(OptimizerCtx& ctx, api::NodeRef& transpose_node) {
+  assert(transpose_node.OpType() == "Transpose");
+
+  // Require a DQ as the input to the current node
+  auto input_node = ctx.graph.GetNodeProducingOutput(transpose_node.Inputs()[0]);
+  if (!input_node || input_node->OpType() != "DequantizeLinear") {
+    return false;
+  }
+
+  auto& dq_node = *input_node;
+
+  // GetValueConsumers sets `comprehensive` to false for graph outputs and implicit inputs.
+  // we know Transpose doesn't have implicit inputs so if nodes are empty it can only be a graph output.
+  auto transpose_output = transpose_node.Outputs()[0];
+  auto consumers = ctx.graph.GetValueConsumers(transpose_output);
+  if (consumers->nodes.empty()) {
+    // DQ -> Transpose -> graph output
+  } else {
+    if (consumers->nodes.size() > 1) {
+      // unexpected to have DQ -> Transpose -> multiple consumers
+      return false;
+    }
+
+    if (consumers->nodes[0]->OpType() == "QuantizeLinear") {
+      // already in QDQ node unit
+      return false;
+    }
+  }
+
+  // Add Q -> DQ after the DQ -> Transpose
+  return MakeQDQNodeUnit(ctx.graph, dq_node);
+}
+
+/// <summary>
+/// Fixes a Transpose QDQ node unit that is missing the DQ at its input due to the Transpose being blocked on the Q.
+/// Inserts a Q -> DQ pair before the sequence Transpose -> Q by using the scale and zp info from the Q node.
+/// Before: prev_node -> Transpose -> Q
+/// After: prev_node -> Q[new] -> DQ[new] -> Transpose -> Q
+/// </summary>
+/// <param name="ctx">Optimizer context</param>
+/// <param name="transpose_node">Transpose node.</param>
+/// <returns>True if Q -> DQ insertion was successful.</returns>
+static bool TryFixTransposeMissingDQ(OptimizerCtx& ctx, api::NodeRef& transpose_node) {
+  assert(transpose_node.OpType() == "Transpose");
+  auto transpose_input_name = transpose_node.Inputs()[0];
+  auto transpose_output_name = transpose_node.Outputs()[0];
+
+  // Require a Q as the single consumer of this transpose node's output.
+  std::unique_ptr<api::NodeRef> maybe_q_node;
+  if (!OutputValueHasSingleConsumerNode(ctx.graph, transpose_node, 0, maybe_q_node) ||
+      maybe_q_node->OpType() != "QuantizeLinear") {
+    return false;
+  }
+
+  // Get the node upstream from the Transpose.
+  auto prev_node = ctx.graph.GetNodeProducingOutput(transpose_input_name);
+  if (prev_node == nullptr) {
+    // Transpose consumes a graph input or constant. Skip.
+    return false;
+  }
+
+  if (prev_node->OpType() == "DequantizeLinear") {
+    // Transpose is already in a QDQ node unit.
+    return false;
+  }
+
+  auto& q_node = *maybe_q_node;
+  const auto q_node_inputs = q_node.Inputs();
+
+  auto transpose_output_consumers = ctx.graph.GetValueConsumers(transpose_output_name);
+  if (!transpose_output_consumers->comprehensive || transpose_output_consumers->nodes.size() != 1) {
+    // Q node should be the only consumer for the Transpose.
+    return false;
+  }
+
+  auto transpose_input_consumers = ctx.graph.GetValueConsumers(transpose_input_name);
+  if (transpose_input_consumers->nodes.size() != 1) {
+    // The transpose node should be the only consumer of its own input.
+    return false;
+  }
+
+  const auto q_domain = q_node.Domain();
+  const auto scale_input = q_node_inputs[1];
+  const auto scale_value_info = ctx.graph.GetValueInfo(scale_input);
+  std::optional<std::string_view> zp_input;
+  std::optional<std::unique_ptr<api::ValueInfoRef>> zp_value_info;
+
+  auto scale_shape = scale_value_info->Shape();
+  if (!scale_shape) {
+    // Axis potentially needs updating due to the transpose but we don't have the required info to do it.
+    return false;
+  }
+
+  if (q_node_inputs.size() > 2) {
+    zp_input = q_node_inputs[2];
+    zp_value_info = ctx.graph.GetValueInfo(zp_input.value());
+  }
+
+  // Per-axis quantization if not a scalar (shape is empty for scalar).
+  // note there could be an axis value as the onnx spec says that is ignored for per-tensor quantization,
+  // so we have to check the shape.
+  const bool update_axis = scale_shape && !scale_shape->empty();
+  int64_t axis = q_node.GetAttributeIntDefault("axis", 1);
+
+  if (update_axis) {
+    auto perm = GetPermAttrIfValid(transpose_node);
+    assert(perm.has_value());  // onnx shape inferencing checks that `perm` is valid
+    NormalizeAndValidateAxis(axis, scale_shape->size());
+    axis = (*perm)[gsl::narrow_cast<size_t>(axis)];  // Note: do not invert permutation.
+  }
+
+  auto transpose_input_shape = ctx.graph.GetValueInfo(transpose_input_name)->Shape();
+
+  // Setup Q node inputs.
+  // We don't connect it to the node preceding the Transpose yet as we will move the output of that to the new DQ first.
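+  // The new Q/DQ pair re-uses the original Q node's scale input (and zero-point, if present), so quantizing and
+  // immediately dequantizing here matches the quantization parameters of the existing Q at the Transpose's output.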
+  std::vector<std::string_view> inputs = {"", scale_input};
+  if (zp_input) {
+    inputs.push_back(zp_input.value());
+  }
+
+  // Add Q
+  auto new_q_node = MakeQuantizeOp(ctx.graph, q_domain, inputs, axis, q_node.GetAttributeInt("block_size"),
+                                   q_node.GetAttributeInt("output_dtype"), q_node.GetAttributeInt("saturate"));
+  auto new_q_node_output = new_q_node->Outputs()[0];
+
+  // Copy value info from the q output for the type information, and update the shape to match Transpose's input
+  ctx.graph.CopyValueInfo(q_node.Outputs()[0], new_q_node_output);  // Q produces same type as the q_node output
+  auto new_q_node_value_info = ctx.graph.GetValueInfo(new_q_node_output);
+  new_q_node_value_info->SetShape(transpose_input_shape ? &*transpose_input_shape : nullptr);
+
+  // update input to connect the DQ to the Q we just added. re-use scale and zp.
+  inputs[0] = new_q_node->Outputs()[0];
+
+  // Add new DQ.
+  auto new_dq_node = MakeDequantizeOp(ctx.graph, q_domain, inputs, axis, q_node.GetAttributeInt("block_size"));
+  auto new_dq_node_output = new_dq_node->Outputs()[0];
+  ctx.graph.CopyValueInfo(transpose_input_name, new_dq_node_output);
+
+  auto prev_node_outputs = prev_node->Outputs();
+  size_t prev_node_output_idx = 0;
+  for (size_t out_idx = 0; out_idx < prev_node_outputs.size(); ++out_idx) {
+    if (prev_node_outputs[out_idx] == transpose_input_name) {
+      prev_node_output_idx = out_idx;
+      break;
+    }
+  }
+
+  // move prev_node output to the new DQ node, and connect prev_node with the new Q node
+  ctx.graph.MoveOutput(*prev_node, prev_node_output_idx, *new_dq_node, 0);
+  std::string_view new_prev_node_output_name = prev_node->Outputs()[prev_node_output_idx];
+  new_q_node->SetInput(0, new_prev_node_output_name);
+  ctx.graph.CopyValueInfo(new_dq_node_output, new_prev_node_output_name);
+
+  return true;
+}
+
+/// <summary>
+/// Fixes QDQ node units that may have been left in an invalid state after the core transpose optimization pass.
+/// </summary>
+/// <param name="ctx">Optimizer context</param>
+/// <returns>True if the graph was modified</returns>
+static bool FixQDQNodeUnits(OptimizerCtx& ctx) {
+  bool changed = false;
+
+  auto graph_nodes = ctx.graph.Nodes();
+  for (size_t i = 0; i < graph_nodes.size(); i++) {
+    auto& node = *graph_nodes[i];
+
+    if (!CanModifyNode(ctx, node)) {
+      continue;
+    }
+
+    std::string_view op_type = node.OpType();
+
+    if (op_type == "QuantizeLinear") {
+      if (TryRemoveEmptyDQQ(ctx, node)) {
+        changed = true;
+        continue;
+      }
+    } else if (op_type == "Transpose") {
+      if (TryFixTransposeMissingQ(ctx, node)) {
+        changed = true;
+        continue;
+      }
+
+      if (TryFixTransposeMissingDQ(ctx, node)) {
+        changed = true;
+        continue;
+      }
+    }
+  }
+
+  return changed;
+}
+
+/// <summary>
+/// Try to constant fold Transpose or Squeeze nodes if their input is a constant.
+/// Returns true if the graph was modified (i.e., at least one of the consumers received a constant-folded value).
+/// </summary>
+/// <param name="ctx">Optimization context state</param>
+/// <param name="node">Squeeze or Transpose node to try to constant-fold</param>
+/// <returns>True if graph was modified. The node may not have been removed in either case.</returns>
+static bool TryConstantFoldNode(OptimizerCtx& ctx, api::NodeRef& node) {
+  std::string_view node_op_type = node.OpType();
+  const bool is_transpose = node_op_type == "Transpose";
+  const bool is_squeeze = node_op_type == "Squeeze";
+
+  if (!is_transpose && !is_squeeze) {
+    return false;
+  }
+
+  std::string_view node_input_name = node.Inputs()[0];
+  auto const_input = ctx.graph.GetLocalConstant(node_input_name);
+  if (const_input == nullptr) {
+    // Doesn't have a constant input. Skip.
+    return false;
+  }
+
+  std::string_view node_output_name = node.Outputs()[0];
+  auto consumers = ctx.graph.GetValueConsumers(node_output_name);
+
+  if (consumers->nodes.empty()) {
+    // No consumers. Skip.
+    return false;
+  }
+
+  std::string_view new_initializer_name;
+
+  // Create new squeezed or transposed initializer.
+  // Once we create this new initializer, we're committed to modifying the graph.
+  if (is_transpose) {
+    std::optional<std::vector<int64_t>> perm = GetPermAttrIfValid(node);
+    if (perm == std::nullopt) {
+      // Invalid transpose perm attribute. Should not happen. Skip.
+      return false;
+    }
+
+    new_initializer_name = ctx.graph.AddInitializer(const_input->DType(),
+                                                    const_input->Shape(),
+                                                    const_input->Data());
+    ctx.graph.TransposeInitializer(new_initializer_name, *perm);
+  } else {
+    assert(is_squeeze);
+    std::optional<std::vector<int64_t>> squeeze_axes = ReadFromAttrOrInput(ctx, node, "axes", /*inp_index*/ 1,
+                                                                           /*opset*/ 13);
+    if (squeeze_axes == std::nullopt) {
+      // Invalid Squeeze axes value. Should not happen. Skip.
+      return false;
+    }
+
+    auto squeezed_shape = SqueezeShape(const_input->Shape(), *squeeze_axes);
+    new_initializer_name = ctx.graph.AddInitializer(const_input->DType(),
+                                                    const_input->Shape(),
+                                                    const_input->Data());
+    ctx.graph.ReshapeInitializer(new_initializer_name, squeezed_shape);
+  }
+
+  // Iterate through consumers and replace their input(s) with the new initializer.
+  for (auto& consumer : consumers->nodes) {
+    std::vector<std::string_view> inputs = consumer->Inputs();
+
+    for (size_t input_idx = 0; input_idx < inputs.size(); input_idx++) {
+      if (inputs[input_idx] == node_output_name) {
+        consumer->SetInput(input_idx, new_initializer_name);
+      }
+    }
+  }
+
+  // Remove original node if its output is unused.
+  if (!ctx.graph.HasValueConsumers(node_output_name)) {
+    ctx.graph.RemoveNode(node);
+  }
+
+  // Remove old initializer if no longer used.
+  // Will not happen if this initializer was unsqueezed/transposed in-place for another consumer.
+  // Will happen if this initializer is a result of a previous constant-folding operation.
+  //
+  // Example: shared_const --+--> Transpose --> Squeeze --> Op0
+  //                         |
+  //                         +--> Op1
+  //
+  // The first call to TryConstantFoldNode(transpose) does not remove shared_const because it is used by 'Op1'.
+  // However, the graph becomes:
+  //   transposed_const --> Squeeze --> Op0
+  //   shared_const --> Op1
+  //
+  // The subsequent call to TryConstantFoldNode(squeeze) removes transposed_const from the graph, and we end up with:
+  //   transposed_squeezed_const --> Op0
+  //   shared_const --> Op1
+  if (!ctx.graph.HasValueConsumers(node_input_name)) {
+    ctx.graph.RemoveInitializer(node_input_name);
+  }
+
+  return true;
+}
+
 // Performs optimization. General algorithm: iterate over nodes in topological order. If a node has a transpose
 // as input, push it through if the transpose cost does not increase and is likely to decrease.
 OptimizeResult OptimizeImpl(OptimizerCtx& ctx) {
@@ -2444,20 +2930,6 @@ OptimizeResult OptimizeImpl(OptimizerCtx& ctx) {
   // Existing nodes assigned to the CPU EP can be modified.
   // New nodes can be created and are directly assigned to the CPU EP by setting onnxruntime::ApiGraph::new_node_ep_
   //
-  const auto can_modify_node = [&ctx](const api::NodeRef& node) {
-    const auto& node_ep = node.GetExecutionProviderType();
-    bool can_modify = false;
-
-    if (node_ep.empty()) {
-      // unassigned nodes can always be modified
-      can_modify = true;
-    } else if (node_ep == ctx.provider_type) {
-      // we can also modify if the EP name in provider_type is not empty and the node is assigned to that EP.
-      can_modify = true;
-    }
-
-    return can_modify;
-  };
 
   // Optimize graph. Nodes will be modified during iteration, but nodes are never deleted before we reach them.
   // New transpose nodes are inserted, but always as an input to an existing node.
@@ -2467,7 +2939,7 @@ OptimizeResult OptimizeImpl(OptimizerCtx& ctx) {
       have_dq = true;
     }
 
-    if (!can_modify_node(node)) {
+    if (!CanModifyNode(ctx, node)) {
       continue;
     }
 
@@ -2495,97 +2967,54 @@ OptimizeResult OptimizeImpl(OptimizerCtx& ctx) {
     return result;
   }
 
-  // Run 'fix up' pass for QDQ node units.
+  // Run constant-folding for Transpose and Squeeze ops to fold sequences like (const --> Squeeze --> DQ)
+  // into (squeezed_const --> DQ).
   //
-  // Repair broken QDQ node unit from Transpose being blocked on Op inside a QDQ node unit.
-  //   DQ -> Transpose -> Op -> Q =>
-  //   DQ -> Transpose -> Q -> DQ -> Op -> Q
+  // These constant-foldable sequences are created when a transpose is pushed through a node that has a shared
+  // initializer as one of its inputs. The node's inputs must be transposed, and for initializer inputs this transpose
+  // is done in-place. Other consumers of this modified shared initializer must get a Transpose node inserted after
+  // the initializer to undo the in-place transformation.
   //
-  // Create QDQ node unit for Transpose after DQ that provides graph output.
-  //   DQ -> Transpose -> graph output =>
-  //   DQ -> Transpose -> Q -> DQ -> graph output
+  // Example:
+  //   in_place_transposed_const ---+--> DQ0 --> Op0 --> Q0 -->
+  //                                |
+  //                                |
+  //                                +--> Transpose[to undo] --> DQ1 --> Op1 --> Q1 -->
   //
-  // Remove empty DQ -> Q pair from moving a Transpose downstream or a Transpose being cancelled out.
-  //   DQ -> Q -> consumer node =>
-  //   consumer node
-
+  // In the above example, constant folding would remove the Transpose and the graph would become:
+  //   in_place_transposed_const --> DQ0 --> Op0 --> Q0 -->
+  //   new_const --> DQ1 --> Op1 --> Q1 -->
+  //
+  // If the shared initializer needs to be broadcast before being transposed in-place, then we'll also end up
+  // with a redundant Squeeze node to undo the broadcast/unsqueeze.
+  //
+  // Example:
+  //   in_place_unsqueezed_transposed_const ---+--> DQ0 --> Op0 --> Q0 -->
+  //                                           |
+  //                                           |
+  //                                           +--> Transpose[to undo] --> Squeeze[to undo] --> DQ1 --> Op1 --> Q1 -->
+  //
+  // In this case, constant folding would remove both the Transpose and Squeeze nodes:
+  //   in_place_unsqueezed_transposed_const --> DQ0 --> Op0 --> Q0 -->
+  //   new_const --> DQ1 --> Op1 --> Q1 -->
   auto graph_nodes = ctx.graph.Nodes();
-  for (size_t i = 1; i < graph_nodes.size(); i++) {
+  for (size_t i = 0; i < graph_nodes.size(); i++) {
     auto& node = *graph_nodes[i];
 
-    if (!can_modify_node(node)) {
+    if (!CanModifyNode(ctx, node)) {
       continue;
     }
 
-    for (size_t i_idx = 0, i_end = node.Inputs().size(); i_idx < i_end; ++i_idx) {
-      // any change requires a DQ as the input to the current node
-      auto input_node = ctx.graph.GetNodeProducingOutput(node.Inputs()[i_idx]);
-      if (!input_node || input_node->OpType() != "DequantizeLinear") {
-        continue;
-      }
-
-      auto& dq_node = *input_node;
-      std::unique_ptr<api::NodeRef> single_consumer_node;
-
-      // remove empty DQ -> Q before a consumer node if the DQ and Q have matching types, scale and zp.
-      if (node.OpType() == "QuantizeLinear") {
-        // we don't need to check scale and zp inputs, and we may remove nodes invalidating `node` if we
-        // continue with the loop of inputs so set i_end to bail
-        i_end = 1;
-
-        auto& q_node = node;
-        if (OutputValueHasSingleConsumerNode(ctx.graph, dq_node, 0, single_consumer_node) &&
-            OutputValueHasSingleConsumerNode(ctx.graph, q_node, 0, single_consumer_node) &&
-            CheckQDQNodePairMatch(ctx.graph, dq_node, q_node)) {
-          // connect Q consumer to DQ input
-          for (size_t j_idx = 0, j_end = single_consumer_node->Inputs().size(); j_idx < j_end; ++j_idx) {
-            if (single_consumer_node->Inputs()[j_idx] == q_node.Outputs()[0]) {
-              single_consumer_node->SetInput(j_idx, dq_node.Inputs()[0]);
-              // break; in theory the Q might be providing multiple inputs.
-            }
-          }
-
-          // disconnect other nodes and remove
-          dq_node.SetInput(0, "");
-          q_node.SetInput(0, "");
-          ctx.graph.RemoveNode(dq_node);
-          ctx.graph.RemoveNode(q_node);
-
-          changed = true;
-          continue;
-        }
-      }
-
-      // DQ -> Transpose => DQ -> Transpose -> Q -> DQ if needed
-      if (node.OpType() == "Transpose") {
-        auto& transpose_node = node;
-
-        // GetValueConsumers sets `comprehensive` to false for graph outputs and implicit inputs.
-        // we know Transpose doesn't have implicit inputs so if nodes are empty it can only be a graph output.
-        auto transpose_output = transpose_node.Outputs()[0];
-        auto consumers = ctx.graph.GetValueConsumers(transpose_output);
-        if (consumers->nodes.empty()) {
-          // DQ -> Transpose -> graph output
-        } else {
-          if (consumers->nodes.size() > 1) {
-            // unexpected to have DQ -> Transpose -> multiple consumers
-            continue;
-          }
-
-          if (consumers->nodes[0]->OpType() == "QuantizeLinear") {
-            // already in QDQ node unit
-            continue;
-          }
-        }
-
-        // Add Q -> DQ after the DQ -> Transpose
-        if (MakeQDQNodeUnit(ctx.graph, dq_node)) {
-          changed = true;
-        }
-      }
+    if (TryConstantFoldNode(ctx, node)) {
+      changed = true;
     }
   }
 
+  // Run 'fix up' pass for QDQ node units.
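+  // This repairs Transpose QDQ node units broken by the pushes above (via TryFixTransposeMissingQ and
+  // TryFixTransposeMissingDQ) and removes empty DQ -> Q pairs (via TryRemoveEmptyDQQ).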
+  if (FixQDQNodeUnits(ctx)) {
+    changed = true;
+  }
+
   result.graph_modified = changed;
   return result;
 }
diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc
index c1cd21570a6a4..16f0752e3f603 100644
--- a/onnxruntime/core/session/inference_session.cc
+++ b/onnxruntime/core/session/inference_session.cc
@@ -1186,8 +1186,11 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool
                                                      std::move(cpu_allocator), debug_graph_fn));
 
   // Previously we ran the L1 transformers to handle constant folding of any initializers that were transposed in
-  // a QDQ format model. The transpose optimizer can now look past DQ nodes to directly update initializers which
-  // takes care of most models without needing this.
+  // a QDQ format model. The transpose optimizer can now do the following, which takes care of most models without
+  // needing this.
+  //   - Look past DQ nodes to directly update initializers in-place.
+  //   - Fix-up broken Transpose QDQ groups.
+  //   - Constant fold inserted Squeeze and Transpose ops.
   //
   // if (modified) {
   //   ORT_RETURN_IF_ERROR_SESSIONID_(
diff --git a/onnxruntime/test/optimizer/transpose_optimizer_test.cc b/onnxruntime/test/optimizer/transpose_optimizer_test.cc
index bae37adb19eef..ea2823916798e 100644
--- a/onnxruntime/test/optimizer/transpose_optimizer_test.cc
+++ b/onnxruntime/test/optimizer/transpose_optimizer_test.cc
@@ -4606,8 +4606,213 @@ TEST(TransposeOptimizerTests, QnnTransposeNonConstBroadcastInput) {
     }
   }
 }
+
+// Layout transform's cost function aggressively pushes down transposes with channel-first or channel-last perms.
+// This can lead to a situation where a channel-first/last Transpose gets stuck after being pushed down an Unsqueeze
+// that makes the Transpose's perm no longer channel-first/last. This breaks the QDQ node units for both the
+// Unsqueeze and the Transpose: DQ -> Unsqueeze -> Transpose -> Q.
+// The transpose optimizer should insert a Q -> DQ pair between the Unsqueeze and Transpose nodes to fix both
+// QDQ node units: DQ -> Unsqueeze -> Q[new] -> DQ[new] -> Transpose -> Q
+TEST(TransposeOptimizerTests, LayoutTransformFixStuckTransposeWithoutDQ) {
+  Status status;
+
+  // Using a sub-model extracted from a model that we tried to run with QNN EP.
+  auto model_uri = ORT_TSTR("testdata/layout_transform_fix_transpose_without_dq.qdq.onnx");
+
+  SessionOptions so;
+
+  // ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kDebugLayoutTransformation, "1"));
+
+  using InternalTestingEP = onnxruntime::internal_testing_ep::InternalTestingExecutionProvider;
+
+  // Set the test EP to support all ops in the model so that the layout transform applies to all nodes
+  const std::unordered_set<std::string> empty_set;
+  auto internal_testing_ep = std::make_unique<InternalTestingEP>(empty_set, empty_set, DataLayout::NHWC);
+  internal_testing_ep->EnableStaticKernels().TakeAllNodes();
+
+  InferenceSessionWrapper session{so, GetEnvironment()};
+  ASSERT_STATUS_OK(session.RegisterExecutionProvider(std::move(internal_testing_ep)));
+  ASSERT_STATUS_OK(session.Load(model_uri));
+  ASSERT_STATUS_OK(session.Initialize());
+
+  const auto& graph = session.GetGraph();
+  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+
+  ASSERT_EQ(op_to_count["Transpose"], 2) << "Should have 2 transposes remaining.";
+
+  std::string expected_ep(onnxruntime::utils::kInternalTestingExecutionProvider);
+  for (const auto& node : graph.Nodes()) {
+    EXPECT_EQ(node.GetExecutionProviderType(), expected_ep) << node.OpType() << " node named '" << node.Name()
+                                                            << "' was not assigned to the internal testing EP.";
+    // All Transpose nodes should be in QDQ node units.
+    if (node.OpType() == "Transpose") {
+      for (auto cur_input = node.InputNodesBegin(), end = node.InputNodesEnd(); cur_input != end; ++cur_input) {
+        EXPECT_EQ(cur_input->OpType(), "DequantizeLinear");
+      }
+
+      for (auto cur_output = node.OutputNodesBegin(), end = node.OutputNodesEnd(); cur_output != end; ++cur_output) {
+        EXPECT_EQ(cur_output->OpType(), "QuantizeLinear");
+      }
+    }
+  }
+}
+
+// Tests the transpose optimizer's ability to constant fold inserted Transpose and Squeeze nodes.
+// After the core transpose optimization loop, the test model contains the following "constant foldable" sequence:
+//
+// unsqueezed_transposed_weight --+--> Transpose ---> Squeeze ---> DequantizeLinear ---> Mul ---> ...
+//                                |
+//                                +--> DequantizeLinear --> Mul --> ...
+//
+// After constant-folding the Transpose and Squeeze nodes, the final model looks like:
+//
+// new_folded_weight ---> DequantizeLinear ---> Mul ---> ...
+// unsqueezed_transposed_weight ---> DequantizeLinear ---> Mul ---> ...
+TEST(TransposeOptimizerTests, LayoutTransformConstantFoldTransposeAndSqueeze) {
+  Status status;
+
+  // The test model has a shared initializer that is unsqueezed and transposed in-place for one consumer.
+  // The other consumer gets a Transpose -> Squeeze sequence inserted before its input.
+  // This Transpose -> Squeeze sequence should get constant-folded.
+  auto model_uri = ORT_TSTR("testdata/layout_transform_const_folding.qdq.onnx");
+
+  SessionOptions so;
+
+  // ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kDebugLayoutTransformation, "1"));
+
+  using InternalTestingEP = onnxruntime::internal_testing_ep::InternalTestingExecutionProvider;
+
+  // Set the test EP to support all ops in the model so that the layout transform applies to all nodes
+  const std::unordered_set<std::string> empty_set;
+  auto internal_testing_ep = std::make_unique<InternalTestingEP>(empty_set, empty_set, DataLayout::NHWC);
+  internal_testing_ep->EnableStaticKernels().TakeAllNodes();
+
+  InferenceSessionWrapper session{so, GetEnvironment()};
+  ASSERT_STATUS_OK(session.RegisterExecutionProvider(std::move(internal_testing_ep)));
+  ASSERT_STATUS_OK(session.Load(model_uri));
+  ASSERT_STATUS_OK(session.Initialize());
+
+  const auto& graph = session.GetGraph();
+  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+
+  // All Squeeze nodes should have been constant folded by transpose optimizer.
+  ASSERT_EQ(op_to_count["Squeeze"], 0) << "Should have 0 Squeeze nodes remaining.";
+
+  // 1 transpose is constant-folded, 1 is canceled, and 1 remains.
+  ASSERT_EQ(op_to_count["Transpose"], 1) << "Should have 1 transpose remaining.";
+
+  std::string expected_ep(onnxruntime::utils::kInternalTestingExecutionProvider);
+  for (const auto& node : graph.Nodes()) {
+    EXPECT_EQ(node.GetExecutionProviderType(), expected_ep) << node.OpType() << " node named '" << node.Name()
+                                                            << "' was not assigned to the internal testing EP.";
+    // All Transpose nodes should be in QDQ node units.
+    if (node.OpType() == "Transpose") {
+      for (auto cur_input = node.InputNodesBegin(), end = node.InputNodesEnd(); cur_input != end; ++cur_input) {
+        EXPECT_EQ(cur_input->OpType(), "DequantizeLinear");
+      }
+
+      for (auto cur_output = node.OutputNodesBegin(), end = node.OutputNodesEnd(); cur_output != end; ++cur_output) {
+        EXPECT_EQ(cur_output->OpType(), "QuantizeLinear");
+      }
+    }
+  }
+}
 #endif  // !defined(ORT_MINIMAL_BUILD) && !defined(DISABLE_CONTRIB_OPS)
 
+// Checks that a model that is not processed by the transpose optimizer produces the same
+// results as the same model that undergoes transpose optimization with constant folding
+// of Transpose and Squeeze nodes.
+TEST(TransposeOptimizerTests, ConstantFoldTransposeAndSqueezeOutputCorrectness) {
+  // This test model has a shared initializer that is unsqueezed and transposed in-place for one consumer.
+  // The other consumer gets a Transpose -> Squeeze sequence inserted before its input.
+  // This Transpose -> Squeeze sequence should get constant-folded.
+  auto model_uri = ORT_TSTR("testdata/layout_transform_const_folding.qdq.onnx");
+
+  RandomValueGenerator random{123};
+  std::vector<int64_t> input_dims{1, 3, 3, 3};
+  std::vector<float> input0_data = random.Gaussian<float>(input_dims, 0.0f, 1.0f);
+  std::vector<float> input1_data = random.Gaussian<float>(input_dims, 0.0f, 1.0f);
+
+  OrtValue input0;
+  OrtValue input1;
+  CreateMLValue<float>(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], input_dims, input0_data, &input0);
+  CreateMLValue<float>(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], input_dims, input1_data, &input1);
+
+  NameMLValMap feeds{{"input0", input0}, {"input1", input1}};
+
+  std::vector<std::string> output_names{"output0", "output1"};
+  std::vector<OrtValue> fetches_orig;
+  std::vector<OrtValue> fetches;
+
+  SessionOptions so;
+  ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsDisableQuantQDQ, "1"));
+  so.graph_optimization_level = TransformerLevel::Default;  // off
+
+  // get results with no modifications to the model
+  {
+    InferenceSessionWrapper session{so, GetEnvironment()};
+    ASSERT_STATUS_OK(session.Load(model_uri));
+    ASSERT_STATUS_OK(session.Initialize());
+    ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches_orig));
+  }
+
+  {
+    InferenceSessionWrapper session{so, GetEnvironment()};
+    ASSERT_STATUS_OK(session.Load(model_uri));
+
+    // We call the ONNX transpose optimizer directly to use a custom cost check function.
+    Graph& graph = session.GetMutableGraph();
+    CPUAllocator allocator;
+
+    namespace alias_oto = onnx_transpose_optimization;
+    auto api_graph = MakeApiGraph(graph,
+                                  TestCPUExecutionProvider()->CreatePreferredAllocators()[0],
+                                  /*new_node_ep*/ nullptr);
+
+    // Use a custom optimization cost check that aggressively pushes channel-last or channel-first transposes.
+    // This causes an existing transpose to be pushed through an op (Op1) with a shared initializer input. The other
+    // consumer (Op0) of the shared initializer will get a "constant-foldable" sequence between itself and its input.
+    //   shared_const --+--> Transpose --> Squeeze --> Op0
+    //                  |
+    //                  +--> Op1
+    auto custom_cost_fn =
+        [](const alias_oto::api::GraphRef& /* graph */,
+           const alias_oto::api::NodeRef& /* node */,
+           const std::vector<int64_t>& perm,
+           const std::unordered_set<std::string>& /* outputs_leading_to_transpose */) -> alias_oto::CostCheckResult {
+      if (perm == alias_oto::ChannelFirstToLastPerm(perm.size()) ||
+          perm == alias_oto::ChannelLastToFirstPerm(perm.size())) {
+        return alias_oto::CostCheckResult::kPushTranspose;
+      }
+
+      return alias_oto::CostCheckResult::kFallThrough;
+    };
+
+    alias_oto::OptimizeResult result = alias_oto::Optimize(*api_graph, /*provider_type*/ "", custom_cost_fn);
+
+    ASSERT_EQ(result.error_msg, std::nullopt);
+    ASSERT_TRUE(result.graph_modified);
+    ASSERT_TRUE(graph.GraphResolveNeeded());
+    ASSERT_STATUS_OK(graph.Resolve());
+
+    // Use this hack to save model for viewing if needed
+    // ASSERT_STATUS_OK(Model::Save(const_cast<Model&>(session.GetModel()), "transpose_opt_updated_const_fold.onnx"));
+
+    std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+    EXPECT_EQ(op_to_count["Squeeze"], 0) << "The Squeeze nodes should have been folded.";
+    EXPECT_EQ(op_to_count["Transpose"], 1) << "1 inserted Transpose should be constant-folded. 
" + << "Only the pre-existing Transpose should remain."; + + ASSERT_STATUS_OK(session.Initialize()); + ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches)); + } + + ASSERT_THAT(fetches_orig[0].Get().DataAsSpan(), + testing::ContainerEq(fetches[0].Get().DataAsSpan())); + ASSERT_THAT(fetches_orig[1].Get().DataAsSpan(), + testing::ContainerEq(fetches[1].Get().DataAsSpan())); +} + static void CheckSharedInitializerHandling(bool broadcast) { auto model_uri = broadcast ? ORT_TSTR("testdata/transpose_optimizer_shared_initializers_broadcast.onnx") : ORT_TSTR("testdata/transpose_optimizer_shared_initializers.onnx"); diff --git a/onnxruntime/test/testdata/layout_transform_const_folding.qdq.onnx b/onnxruntime/test/testdata/layout_transform_const_folding.qdq.onnx new file mode 100644 index 0000000000000000000000000000000000000000..d5fe768b47efa2a0c74282f54f694e44162bbaf5 GIT binary patch literal 4845 zcmb_gOK;;;6wY$8FOzgiJ^UNJjF+2mzWIG{wvyF@mWEiiD7m&Fc!^C%Ja(|*AIy>!OT>ylz;6iGkL&muJBUT4=ALu!`ObG<_nHQ_s%zRQ ze`L#=rJgIzr5-Qvy=#AN!#@ZbeO)sxNsvwy_2AHwh}HnsV-mA3j}+L7t)44}E=_b* zvsfV}x}}iA-c&V3Hl$tKn%LIWfjt?iec4h5O)h6kNVm(x&*L70ihQb?4RnF1=ew^C z2Fjpw|EzOcGi^h8S5uE{<<`B|b>pr+aDH0w1%gYehB5SDb!?A%l5Sg)=U2^JmqPxu z{H1$7aS<!OnYE+wh9B5OvYy-RN;*5ZX?-0zjQUqq(2NI{pr@OhCA6joNa+Vj=&qGmxcJFs}QvNWlg=Cb>r8|Cs__6WX45C{`o ztdVWu;%Wzr=|SAgkS`d&hX`^k6@%Le;AF*ZD^nvcWIrO*RIoq%9-QC@L=l0$OD#OpXlg>L9szmDX$daES4xc1z=_y%H7cZERpCFiTrwG!^ zK9pvIV2X!4FPfi`Q7FKQJ0OI}AQZy2>WwHN9|!`t5(cBRzwYJAB>BLVG1gFAw#xKfIIv5<&|iA;p@Ai(jm`>j~0q z5d0LBVfuzx9Q`VcWSd_diT!@M`BfMfFPLv12UQq9i0~7wiZq~p5elRTlU(OSiEWSM zGhD!uJM#B+V=Q@L-`7V2RXgAs48@_6n7o^ES7}6-X^N&8N=zRfcUC?v9gZLV^84q* z=~v%!Z3cBKA3b<*T}-{#a~;oA_Vu3Irpwq^=6m5zc{-hPP4Aerk&nH>5dJ#G5Mr0_ zsGs3v;Ofqv^Y=cyE=J01aRCUmh9x@9v|LjtpObqMZYP##B@s3Wl!~yBQ?ea~%)BCaw%#Q^*AS5}CD& zdiiqh<0lE#<#tGRCcA;(S>{CqHv@bHN`Q5HZRS4AZS>8MdVR^{=w`teL#F*8akI0s zJ9{2v9)pPsuF#8la4r90nrRi6fk%2GwbDorN^uxuR$^ajO3pS0Z1rN8W-2%T%i}Cm zZDIvV%$m3fYkhspkJY}RoBF=RYyHEMK7+(n09obTk3!=Svm~wqLw@5#QZ+n!uvQBi zWOZUF$EtF|hom*AMh?dt+R%(F@yaQF^s8e7yOz)#LB-RK|k_=||fAmcfx n&puD_*pf6c^uA)6Y|U9WZ2Zw`tax9qj4J@Hva0PXMvB6}FzgP6 literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/layout_transform_fix_transpose_without_dq.qdq.onnx b/onnxruntime/test/testdata/layout_transform_fix_transpose_without_dq.qdq.onnx new file mode 100644 index 0000000000000000000000000000000000000000..62c4e60c2c43b59e723c711e032f21bf27483c0f GIT binary patch literal 24927 zcmbTeXRu}0btcB9I40F2k0^|0hGvZtDN1CE6vYH0cS57%%lYP<&bcS&oWl+0<{Vzm z{W^C8Xheb_2||=CS)Qt~9BRfcSNW&?XYBGH`;wv{k^)BLyjQR8Is2@=&boW;z4i*< zexLfUp7;z5LCl|$L>9#VG|3QHG){!`a3Tpl(|YIB=|6tr6Hk5XTi^KX$3FL|i!g)} zQHVJ6FHiq7DhiUxGoSp_sZW3MnNz1WPJQCEa_aP_>0~yYA;Ks*p~`Z}#OdGoX$u~Y zg2~gr_0u~piNRq6&O*c!zx&gs7n6VA=|&bxMzQb=NgyQg*?;p9&p-F+U+nSB$4+-o zo&Fc+6N&KG!thUehyU#KZ~y##`0oVyZ+@ZmX9W7gUwAb4)9>>0V*Sp4mg;l=;^%u! 
zM0;@R^e@HvR~Pt?enFI99GAZz$=~}Y=st8rlc#@4_D?*qh(#eBf$`)Aga6rI|Df%2 z|8RNcw@!ZoNn-hPA3OJJ=T4pb_{OPEp58fi`a}Bv>SOW64_^3z5aFL!#ov#{AN-S) zAELy!Po4gy-|(vof969(`Y2_7{+FQrP=)^R)ahU0J*WSP<@v;aJ+TSY2V?!&ADakj zdSW|1_lX~0oq6K)uP3r#l*mpj$3)i?JMqagzj6AXr7@U9J~O%d_?dro`Zs_i1|`8T zo{1vKzbE?f9Em3I$sGB~?~nU^`t&Dp0?Q_nf6wW3_aACUg4gYdW>_7jR)}J-ozxjdL z{`gN8^67(9r$3a#uRJn8^TH3r_z1QBjNSVOk$-{qLlpb&&sn{X?D=1)+An|T7x*aU ze&Ls*{!sPCPX9`;I{g#R?TO!Bf^d{fCb{I#F4BMgGp(Qd^$%`-qSX)Xex}2(tkn-= z@(Y^!`_1_CzXa_^X!V!q`CqEl4`uz4TK%P{KTNAneD>1cIo-*`@24<0|NA@vWARKl zmnFnAU;EUrefpCZ{;}P?_L);3`@5fWKmGjJr~mHa%;!G+i8H_UfBfI4E_}3;d**jP zed=?c_)lMc=cx-H>K~u^gCFyF^Zz+@;Uo9XGr#j=7FVu);^_+?w&|Yv*FS!I;bZ^m z*$Y1x0-X8Gr$2t?lyKo6M*?RyH%@)(+^63D_@_=yf{Kq{_(!qCHD%*eMc?j18HI-H z@_)q4Q5#ByR3;a?%5^yNu<)a8fY?C6O7TS#!_Z>N^hUo|4Nl_<(YZ$|X3aX1AHGE8 zNrV@z{unN*ygE*2Hs8Eu;?HPrujk_>nLg_zoV%~-+u>`GltS?=GEEkVE)mPf7u^0# zGba>bCTsGvSjJ!I9y1Q)?|!q&hk^D?CihGXoU%18AYURW6jwWaef-0Rfo~*)?vd(t z!E>F^mw2kl7+jJW2*WfD;$*n`l2NrDxpAQ#P@vHYT!;^_(lvILoGw#iAZirsm1bHm z8z{~R)vcV%iAqpS`@6ktESJ4vfNzYKfvK;)ne5HCYA2iu^B7W{me6v{&a>V2Yi5NH z&#(DalfGCqE;DrO8_mgMWX^FQ+5SY zvsXBW=NDROMvKekzj!v7t5L#Sb%~*Si)<#)y5;VUmf_8tY4gs?3Jm8~0K&{8TRk*|;9Tz*7 z@Xb=pe_@(frPgt;FPkzYN&l~0%J7jJFkEeFT@x%M*c4mI|Qsc z2(FTh(z*PbbvQ3x+a_r8MV8sP7|%D`V>(ld`VPl2xfw~RW{k0AE*5>y#LGqB{$IYA zW)rB>7V(H1j5OrHxh>Ur6VNf9H|dn{W`Nh~qD|!TI4psOLPdpEh3HT+MAeujzNXn^ zjpiYzvh3W+7HQL>0A8|k(SEGeuJ9%W{BYaI*$r~3Sl!k4gzp_c$WUA!PPaS(&eq(e z${up_dUhov?sblH#YpnG+A?qhYYE@RxvW(P|KRn=iyKv3zMyMjzh+aonu@~ZKG^~9 zzA>-fWz(kV6uyXlf4)_bFBy853jK)~!3U4x`EbUP2*shh#XZ$Y3zH>wLxUK#HSNv}ZK9#C%Ke?~HVuJzOc2=%UdPGgaaq#~^f& zUZ`XZb>UGA7PXq}6cdo~H96iNM-{^%=o3o3o`7n(=bTycFOwa7nTE;-NH&Uits6@s z3YpaAHhVY{fi|o7)I0o|`0|4jLy;-ofVt+Zs#;-6;)*CE@q4Z5c5oZK(Ex8~u*)ga zfOLn(-i=1VVk%q9sZx9pgl=sRiCIXwoV7=57ybS&;7O3oL6y2ujy$aTq4ZR+&3_36 zsNveqQZ&f{n6!_-!-rxu2KnI<@M=3Q^8R?9%w7Wb6H@91B4SFJzBT{m8T{m75mc*; zb-q$E3fqs2RjpzacUtV(c_%)~zmwUV6%xgHAX6^hU8!T|Qro<{Z*f8Pt@N74nKVFA zo5>v08;hxv&4h{?ra~B(E*Db1MX}Or(ec&$Yubw8YRXrb{dHoRVgkC@AB-zyJO6{5 z$7C~-czRn&YSrRCV`;Zzoh^MCSD%r+i`0_7ti}6LU1d~%*Do_qh5cdhKizh+f+Bms z_dO?vC^~{JGyC;RvXd??#lXrE&A2wnUTN;TE(paFIy*>3*SDEWBC~l5yQYP`A{xC5 zQI#Hy6wo{q%uLq|D|7wvOF*|dtTjyOT24bHQfIz4?3PuUy`60A4Ehn-aNsq6i!b3; z49w;me4ALVB~bEuG2faD7;0DF7&rSDTX=Qb0O>wqwMb&^TPPbMo+TM)y?%VP^I}d1 z(Y;iGkzLgqbvh+6k@=f%7c+)-tX+~w3olrrMHP`yrM+`8+QK{G#LeV@ux_^dd3rW^ zxHWTS+jSdOOc!gK4!quX_{-F`@yaOpp`~^P7+Q(DVp(em?J_oPHu>GCfl9e_ zb7DBrRt6>_PZ?o*gZG$TD!?c$H=F{a%(A{c@_cN&(phIrN}1V3E2^cowMImd_IksO z#nc+TK3Fa7ZEJaEZrMpJYfgDNdw?d)Y6&1Z!xUNEY-dLv=VqEqF-2oGBzzbE!)QC7 zM1E8dkmuZXIFAU~KVPY%h~_qi}u5cTXxD$yuwi*srFM*9bkc-Sp&UFw?4eyD0H8VRmSx6?Cj4XqU00ddE-6ST?4n@KtM*RbU;DcZF3bx+y47J zsqiP6xH=@r1YrO91NRQNuf9UP=IDF$K(D)G!YQmn*WO2=X49ZK_dxI0_^j)^e@}pZH=oMqtbh{mOW;osG<9m#^JdQz|4Vk$U+G7uqDs z$^4V!S}8S6fm`J_;c{7w<#p?Lz%3};wEdDq_L+3uEx!5c+=Xz_uji*lHTNhHpcs=A zbYb_(l*rNM%;Usm2JD6kTx-;$@<9nEmso+^6%?N-vG>>=;g*4#H7%=TBY$;N4TH7z zI<$x!Z}O4=QdoSSaHAJ@8@!~$VzJ=OYQ`Wp)C&vL5G#c?VRpWvi}^u(eNMV~F-(t4 zU^HD7=q1L)D`HZV;vrlkJ6@HK(~0a5UQ9}2pno%TFtnBlGA$WpY_Mz_*0YOIKqKc8 zrq0!!JqWN+Hs`G4Tr=$FsH-K4lOLdRXtGRaHWL)z->QhrDvDaz^%3)zMwTQ}s+FYX z%Mo__rdWYJW&Z?#@^ckSELe<30uC^P*9}EU*HDQ;@gCbi80sz+E68xC2T#$(Du*VD z$gZzRBW7|nv9XnfSR^u*8<<2I5W7Eid`&dP2B1{gF*5_&5W*eowN5K<@-nDz&x?AB zFTaEud1@iM#RHvnE$XoSRf1ig76qz7ZizuMzlsNnU|v2>I=q8>E3ATcA}7kmO7kK?n@8!=zCpd#rp3Fgwk5!x}}`dp0TH)E*W&>xPBdZ zr0SrQ!k0fMdQPjZ!nZ1`$?(mIH2~ju6BUMqOl8US`x6N<#hhRLVHb(0XrKIgX{Q4e zJdh386RQd(Zk49~+5>UKC^7z4)h_AI;Q94^gk~=ADb>4a+iV}P?2_HG^VFQc)=KqK zxWKEu=_4KsibB1LLq|c$AtMwARi_N%>QfiD4+kY(e#yj#vp+17$Quk*cIV1V!bssQ 
zoTZyJA{0jo_!}}@Z$4`fcxQ7}CU$1EBzKL37Hv+KNu_n4tCR^?^D3%r(ikx=&Zl#j zOqSP5e3Wf_RxtkuBm7Y-!rh6I@<`u%RIpoke!w2gaNBk9c_8r;9g}0qTGa|p*F@n9 zT;0B`rXue7I+4`$w_^WaF)<`zPyi|jTj+rMvnF!3S-zDXpcQfUe-32iyogn93MKtK zl6_r2ZZ(=g&tDmT2h(=OIUcLaKJn7S4hJ5D1sRJ}TVjq> zW+O`RXzhPbuvtC$e#cp0?uq#(EJ~ehC%ed^+FqJoO$(|_mO8a)H`h_u$k3bV)MX0a zvf~(^R9RQQT{A0Yf)gl3w6TMQVL5QaFLQ+ES$thuwauIBKrZUah{^vW<6&FiR zyT_H=v0$9EWV|Gm4`W8XS3YDa2oo2Z(Wn_W?vp=Imt%y|u%?|G^U<+3C?Dw0x0rSM zU14ZvX21a$JRbRrSCVKKJf0G6mwI|A%$K(#>x-A#Jgo|SfvwW&}vsAXKy<3u1px=_Wjk>9xbx3`cswC1wJXM}8#e*q{20fT%n~0jX ztc3Ck7AmU}M!OMoUwj-M%x1Ve?v;yIG-fNA%@P|UC4*Nz#9ap6ywuBQ2|>8fQJyQ? z43yFw;pL6WGsz9{QDxn<4c4ZM3|{;0Taw0zbhcTO(I{Kn3dA@ppl=6hjnpD_tv>Pm z7~3iUq_w+OBa-!PC7%Lnsvd8Fo#fo+rQPm$du8L&?5$iTAlv2L>r6|I{6*|~K3(RI z>P3-?5!;~vnVC;6D|9i*k}E(?+ucb#Hp&5|L;IV_py6~7kJ&TzX`H!zahu-U(x3-r z2#t6PfnH)A+kB;waWc2E>#;Hp7&{wx6Z!tvDtT&f3Yndj6!w0=K^&kCT{BY>sY1LR z=%@Uq&JX3{S}eM37Qm7sk4vmH6w2>l1fYUX>#*WS%gHFf2)%4AV`*uEg*2m0soA4d zAZZX{w#_(l5i9Rt)}mxaV|5}xaRCOSY6j%Y_rxPb%)h!9m*>q~_S-F&FHd2qXn08v zfNWGi><~rAmotRSgAw5&xx;~}Uf##Obl?|+qZASEEN;R$`|8T(%yd+GI|@6G6L?V0 zH)=gvS(4Ks%!QVEl+C*M#8!e<$&sbk+is$>TvRKyA-Er@V};lm(Xk8jYc-C5k0b>t zbG1qmuU0qi3|XWqv0H4DExZFahS`mQi5}5MIRc&Srb|qB4!+|}{JN5(q?%iOluK-- z(y^>Mw~^y#oe`6G*K)0jx}#^=n(fgpA~q~y2__XeCimQ3@Mt!_nRb@oSp3WRt}Y3M zUJfaO-d7IRt+c zyzf)&c$^EB?Ucu*cjyG}wK91{Puu$yP$939uU;NqEYZYlXF0lZG2RN2K!BEiR0f@G zsdd60Z1nc|m9lsHa%3^08Zg|Aq=!r8VFuW&SMHiYmn)aIF{V-M2La}NZxO%cp=-n% zdao!ZC`q?p^o`z9yFCTGCjiO+$&Km(*4Ux1H27iCE~t{XEz+C4I*~E!310*sp4%vu z3d>C4dhUe`l4isvKN4hy2JFlVFvXDx4a4DP8PdqcBZbmu;UX*SLY{balkT@>E&0+= zHh*{+WvKN?p7{69qAPA;!ZwZi(=j=;ME zYF)JCKT4FsSV6%>P^6}Dn_1MGTg&T>Xjp6c^_4b)?fc>QROhgr2N@4x>sDiuAHezb zTJtKtzht#abY{l}ttED?9waU>{4kQvQk3HL(%Dck4*K@y_%0L>-PMA@FAiSKH0udT zflw7TXpCaZ!nc9wAa}F%u%flg_f}w9>!3h(of7oCl}lTl%Y(t;fXwdKZ)U5&%?8lY zE6g;%^%~SU$Hr)i!NfXZ_?(Per&)og6E~!rah_Anne!AUjP0xjN$c5ghImEW_#;vk zh=X1yLW!2buoI`3CtDBAFz&__5PO@pmFeWz{p?zGFT~t+JbW0@^E0wca+$vORGG<6S6G#Wy;8mO}U9|}XlxzW))-cE_ed`8gjkx6|FT-NRsCU5)vQP=oKF;B!UlR zggTFw`KngxDc9=6O=4Qj!;BTU4{ZusEo)=QV43Y)VOk3nQuu!qasnDi$OaZ zSOu7Q0gbqFOp3Rdw2=y#*v^Pww%+D$Gj0KKcBAHfbf=Gu<85wTB+l04+EdP!JTIE@ z=k;{4X4H_0v4u*Qqrq#*j-)x6o@NQMM7Rir8n*{YP?( zi4C`)&9ukZ=R$@iR#IfsUClfc%WNSx;#ivPXqfce;;PoGMH6m_?6Uc*dnbqT z`5^<(ne3q2oEmIxliMH$cyJ>n5I0uTLEKJ~=ZdYo5qn9j#Tz9zvy`%qE&V`|1%^>n zJ}H6-=SXnEy6iQZg>*`?YjE2pZG0Z9Zd4vH(4~!h>cwSYyRyO!d)T4zOy)JV(g7YX zoVBy_@R})A#UGX7p{CT(VKonAfVAp0fM@bmTaN{_fM)2EXgF!Ftor7-^=B8{QGBFS zC+<;Syl5R(;idH>ohj=EOxoV2irvl0sR` zejTtci?$-Rs;l(|ojpnaTXVg_Ctv59hMq{2rCS;&qvRuqhxXyN4rW&E?s6#m==!Xb z0p%fY)bI8R&6_zZs>fYR{N0CqlC8Z{w4$|G+FxKaLrUePbHsua$<0rx$sBO$;S$=y zI4tE?)LJ+W=XR&&>&dK21>I1FRcR$YmeZEc5eVg)f@W7jrmSs%T2fVC#ff;&JOHGa zw7K8V6qWSNJCy>kXDU#}VU(^>jE%uNqQ6>5n{}KpEZ&)9#?!GIaJqT+#w;+F#Cc&E zUGQlRX7*^r)jC%gimjAn8T05wf2>}$ReV_S=4p9qUMFQPDk+r;kIeTRWuTPO7R^8! 
z^J?{8-9(H@Qs;I>UYAB2zS0KLN$3?(&XnaaBYxQiY2EWh7eOh{H*C0>DOU!Z98BPc zxjh;xW)-vvY=WIgNp^%z)L1X<6_}WM+gV$xNc>gN93#aFr?rrk4id|+%rqP6CzIWm zQ~>gV`cv+8w8NeeZo(`0&K@RjFyd^-#U3*%dO?#Gr3MWd@&-+NoVkWbhjz2Mq;&G5 z?-I0Sf|YR_B(oeOVo+RwT6xvJl3N=mv@ICF+AMXQ?!8t^O1O3hUwe`72Mg(0g5o&) zmJH|XWkhG8O5ScQ?&Fj1#csa~qC#n9%}HciC8BV=)8uUnH1g6YPw=23bCIk`RH6;K z1k)f}eV$s9odcd55WuzQJ7PhC&(*5nb1AAOnLP z*Nz&mp}8+z<0iX;qAe71t;{qoTvpbc*(Ph%;=~#N8pvkD!aw28ct=R$3k|vPdR~IS z&?=qj=onE(l<%RkC$as&W?YyAx*XQi*Go~@9PFxj01tteHVc;K^>|(*(v(rJPpW>K zhg=9k@a-_(k#<%4`B;7p1R;HFH56_irtUrFssqvreVZvCy}hN23?a2~4TAgEhK;%2 zW82;-H4yN(6+Vc zHMGYJKT}Ac7^AoY3CdfVaijAUYO;x}wR5G)bXU4d-QpkD+V{rvHZf?mz-*FW+(Me9X*JFb502< z2a4}orH-u4W*VhTo+js7!LPr(C!}MY1K;KC=zEvvhIipdrCiNrX9O$8e7DN11W!^c z-)Z<&xhY~={)%{;e%dw^ky#4Q$l((jT^^%kz!52fYyo4|RO3-9Xd6|&7+UDyaadra zjmpyt!Z@sF$dPQ7mRWfIia-Eq5L-LJT1>y5r>>M(a)iF0Fd8W5h3@3yd}dop$oP92 zXAGuwrd=%-@8`#Km0Y+QO1uJwyF%~v${>0tWK+c{(a`hIP@Sbizr{>1w?>n+@8ODC zwdJg6wOZQw12aHUH6&t{6uMv-yuHeGQukKvhQijQ)^x6(CcEY8@Wz(wV=u>#6T4?O ztF7(g-qgYxks0V-}J2$?Oc%C0@PVEf;T+y(o!;c6>ax$<3K(P3*C&b3lw-NudRry7Rm``brN$D z_+pgEo6%G%Kf>%4?QSC*5kd31=C!N6bETEkMw;(enhPk=)ZOMma(rkT)ZT%gAHuvZ z`#sa0cbB!p@nWJ|C>43948~_IzVl-Kza&8{exAz$7j!8N4e>=~AGc#o75SWK@ zwFlm)&AZdQjD~Nh?6IZvYeZ@g?pH2Fjjv;?JdaUcStYs(^?Q>@mA9U+2bb|w0yI!v z38S|O&d{aRc`6+@uyc6A7`Wf8*0*VHBfm7~?wSR$Oiz;Dw-_c~U<%+@`4h^bCW)s* zk$^sSyMB8xaG>;lV;g%&uWTR>1txwXM=i<*a5y(70p;0(Zm`^}jF+4w z;rDa75@hx2X<$>x^%01ke$YQpt`M9wA>QWwr~ zYfP@XPw8^ZjN&HPNVdgdYOO`YPyufQ&3O>Qi!va^7xVMI*@mJbIyvuw@%0IGhNeml zBeFZ#2ePOf0XhLHRM>9ivF0_YSSHw=qCL1B|86tq<)iNY0E~&6Z%I{aUC&?(9N;Bw zJf;-rCd3~a{WQ@?if(g~uI#P)7&R;(oe)N=G69^_!+1*-3fo*_4Jrp-yB6tYI-}V) zV!Otz*lKlFsWl!*2AQFwY_NcMnfeLAUxU)Npj&dEW|a~z;xL5Lt?iBr3Sqp;c8L~K zA?-0m-e4m#+-!hLv`rZ2vOOm`8(V6p@{#7(vGw zJ6cFpCW+M$bGhP`0y%)c1a)x4Dh-JReNO6E_3uSUcvbSQ(B%ggI(&ce1Zi~4NrG2F zkMq=G$jY_e9mjBwTgwRY)fYp!lK$ezm)CO-Hk>qi?$PUOOYMS{^VcTn9HBZXgyZ!@ z9?c{6jceK%Ju~C=r$-jb_{=Ibhb=VCff09oW`y(DJX`1IwmGaX%W^2y1}-nrXgijw zj&5tr8mt{R85t!1A~$Tb!u#e9o=zx0p(+tZ*14Kw3um#cVDDy#HZsiX_98%NZLYC~ zs1|o~T~|0Wl&FPOqgg6t9!*vOaJ#$;L4YJzLi9S7l5jESpB~;47=wjI%mcYL!cZuN z%imjj@!1UDEbjJ2=L*ODjXcLiX?tA_*4?jWH_|b|b75?2$(3F#hD1aai5ZWrzZhk_;b!tc>$ZmrCU=HuToRuRdrGBNBLKI(7%u5Z1VLSZULL-q*7d^RfRP?Tph z2J>PLz!d0 zdNNMaTGqfe7tE#HyqQUwP0>|C{?}=9Z`kbuYb& zGy}8#OU~4CMeT=@Tcptnd-ZHMzhN|E8G6}Rhjew9)k_y>K#AT*D)9z%@Xazmm9B@6 zpVPj zJCK2&Sh*OP6i06Q;C@Eaih&Li5rTJkev*@sk}8$5?E_QEB@<%h&RE?nL_KR;D7!fu zfdovdTeEhR1@Ee^Nuvd>EcJ`P|LHvGDcSI|G~+L?qV$$i!p^)xa#Ks;a#X&y3B zS9B+iLC$msptQeB8C6&7tD83;{yR|zSE}g|c_ruB?R=(31i(&s?uJgKIHS;B4JiCy z-%aVTSx3!gNyVj}$L=v)t`Qfbxn};ZrQ}zo73=rp6@k4}sd>K6_zqVoQ76drb*D^J z=@`8~H1o{Hy%aawTD5PmrfsgG`s>wmBNgOQqMT4o)AX(G=_Oje?#`M*Cqa6}THOx+$u-50&6xyVE z#eAXME?>dc81<0n6Z4uaEoa)j8doj!(%-%reqKE4JDbIhTtx_~7=IY8VZ13tqfT+M z#h2BjDE8l=DY$6B%2B$CM@@vBD;A}C@+RTKOsbp|485==9kLPXJNi8LP|1~SA#^{M z;^n9EjFKHSbcJ{fCmq9bQF;06tf)Vk!u~yHvv#m2x3OSpZBpggzLDUHdkeYWpt66K zo-$AVL$s#edG7gc!HaM;UV=e;tQ3I@4E5E+w_MVf=cTG2u#%nqE`~?4)YHNFL0uV9 z;J%27bBm*U1#5 zyv5os4NUDc<99hT>T|V`HDpBd@Ti=o&0ABs`(^pA7Y5-KBYkGBg#}@Tri@2KO-J34 zyS~hias@!%Qy1-Sz~ydZyMxf5(eXxHTCXFz)$uXY?=?P0!}_FP_DBQ%sxWPPwl&F8 zS2&4e)a2ETNCr|%T3$7;KSwpwbxixS%5~N(!mrw)R&N}%>;x~9HuVQ=Ft5O6!5cF}Zdp9Wh$V5P>o7#dybt9R|1aW8e`RTj`O{(@*m#*YS zS9u%~Cjdd>vM6HI66Rw75v}IbPB-Eed?Ss`EbSHo2P5VWp&TT@0nyWqbe4!fLrlY;Q5% zqhMf6w~N;+i|ml{Hzr`(7#OX`47c#` z1n_A*C+wV0#gt4PpF{@`R*tl#6!YRrE?+XpOZBP*Z*C`PsAh{1>RUROPWD33ab#A3 z3xoJ1I3V$X%We-r8|&Q5g_+hGYfUCUjcvuOYng8^e1l}rV+Po~TnYN%C)0D8kYlfy zl-fsw4Z8}4kxhc%?kv-=Tz=>?)Fy$aZz{d>QV#!oE%#a-h|?vJ0r%XUyYD7eoA;}F 
zy#3OBrvR;Iv|dNy@*XekeXnXeQVJ^iV#OfvUg;R06zV)vVySmCTKBA)JaFbs<+4j%>uSe^}~ zHn#w9TZebqG^zK&NnI07sJH^tRv;|_S;)ZDo01kPX|{FM+-RI-Dxx0A#tWp*Eg9PE z5r(>HL~^pY-HM5Wu0iXw6j8xE`}rEIlJ!;tGS_CA!lbGs#~Ybq#vNcUvr1LdgX_+Y zJBc~ULS&DuZyDQ6lR+y_4Hh$TCtTsEz0@Nvu@J_t`-Ulcb*vN`>-MIr=G;I-&aj8( zj33)FAQ{OlLX@vG@ec>3<)Pv2+$cM$7;`&`o#I~j7{k%3f;WRo;T6uW;YY;)sITJf zoGhw$kXD-nYDJ+^B|wK#Eqj0o-DuH{nIk;gSsmcTLfTrSl2S5))LXjyJeO%wsGyPQ z#fmMI<1Lxn>#}T;F`3!6R%kj_u^Adu`Na)5MJknnF{ud^K%GrkV>d`32I8dH7FyzC zbU@u~VLptr&o5Qe&?=8v9EeQKdw+5ginBI?Kafph;pMk$t*Dw6CDNUscVR8HuzDGm+;P|q1ZCV`YkF7m8BVipsHSD*{JW) z^Uofo;Km~Ousm-LaGTIx-l|kKdI|7~HR}3Gg?H!Yt6erY0ojh&=0G-!K`WIN;rOmr zoI)ZHrWEcLrvV}(^QH2jZi~wlw6a;Wg}gzgf2-mx|Q>@c^BXAs^V8-RrH~ORLa||llKuG0WG|VMui^EwhVJ?djrPaZl zfu~EdB4Jz zzbi`bx9CtLxD>tKSC95?;pJOylVuS$_)e`)S%^Ji-wxu+MGjdN6EeE-0}UMoXolg^ z*RJOQvq<9f#Bp)~tL(?o-g!|M-}d{`*yFA3<9kNuZ#L4RkfDKjaFVjun2OM{SoUG# z`dL^o1>rCq<}-Ejt;aLWa@Hvku5h0DW|Ka~i^T=D6*}>#&EgFjWdu*yGRxZ;O1-LT zmRuBb$anK_C|Krv4)Ii78dcTZKJc~5AVLrUpiy#O-&r%Vh7ZZ7}BG+?zYicd$MDI^PL6aR@%XsHa{-ko!SDl%0jFj;h_o6y7V--Ji#I~Li|ZvP40R8&|5oBuZyAiC>Oua zl0F!}QQH!@a@vKb{D4oPgiejx_RV_e$|OgK*JJE-bHd+pvDrbsfGZ! zOK%)(C^ZQI-TK?rW~EIN!D@3hhS+QoP20h)W$I+>LgEFFN?j%t$~3fTW>2nmcjWZ~ zdWF<#p-LZJD_*d63aT$^gLB>aoXi^gYcUK$%rIJcR=hcf<4I%7#8x`p@Z}s;`Mzj` z&ECv99c9Y&)kB!xIN3KUC1PfR?0-|gEK}leW*}PI9*VPl# zN+nZyWgfNgV%}JMGqVB8BwN(3&|X;>Doj&*Ss~5F1lMCZm23zFARqmD(rx7=;XJ~c zP$zdMr|V$;s^8tqAJm?%EL%WUZ@RMHx?K33&?M@Lr1M+ex^w4HR)V)gk>fjR`>^aS z+%$CFTwIuSvur?^;AuIfb{4+bWUE3^$X`XOtnFa4io+@4|sCR9B?6vQfIw5A8MA|`*`;GSIa#Qcw5JUakZ8!o5)O$ zWcdhLd2NPvfLw-~9A%O=fJ}?9mE3II9_KcnEU4U~8v(E~;@`I4YiR4ndj$xz_Df8!rM{AjKX${(= z{>@%V5{shXnkHZn{;bpa7$s$Q%uTL$CGlk!mdMn0FVe9K*qyeJQj~wrye*EK{5pQG*vFO{?Xu0+H_9{2-^!?FQ~)ew2xT?X zc*Dt(@j7=&<-gU5y_SSjU(mk85*gi!LKVFUJZ^D)t0O7xK`r7S!ckhPSX}3o+2&~V zn^|-Pu6VJEB4-&+$PHFZQ-~?J{Mas`y)HtPT;02>C1Vr+C8QHAwroTPwo>-tzS1sO z%Q$XHNfvCYIQb_WR>56$O>V@pf*4ING|8(K3KUbwAeXBhrO7&U%%MXY3LA^n#&2kl(K% zdu3p^+@LWg&Q_Fv&15v^2r;w=}9%u5C}M z+U4)Z7d_GANImGWm&9AvXaYv}gc2uQe@DBs%;yHNizvRBy(-{qOdAn`8(p(ldYR^P zuaF!100|{*JFUht?9C}Fdi7f)L97DBmsSO#qwk~b^t~;5xG}^tEO;>Y{+NSLY2<~>N{pC>8z(SH*DZ!T3;Uw->MCckZFlv zwk(=1OWUZ8@mg>SLt$6-%yng9zE)6r1wB57$X%a9L5ZiYtnn_+Z%6j zm8!1ZB0KXeng9tJ7OjF%(@MVhk6^nXCHysEKe+devsBnPtRa$^DW(U=YVF)1TV*c~ zUEHs|Jr_e?5YtAUt4*@CT_GSDz?$x(`&yZA#Et5RpF{<6z+J?$xAV2^ z9^^le!L76~RcIeTB2~SJj%4{YV~i--g+y?5L+sz+XAVona_v*leP zx0h+258JhAPLu9Ew}y;!6-AG6+S!$Meb7gD;%>i=Ju6iuhpA-rmwEkKbJ1eIBzygf zwr{#F`Fgr(gNms|hr?&UPS@rphawhWTUlU}N2Z6Dr-1mmg3@cf}7xTijgM z^Ba<01Yzi%@8>6m$` zIyp#8a1XhRcY^F1ApiTtYGxs?8bGF%^U7egR_2>wQHdkGAIz2>- zIYC>?1^>_1Ym&^6E6qKGVU7fkEL%F*%?trx9tM&k0MxpWq^Yi&_fq1C9ragI91o?| z`BiUnWJR%*d!t;7J)iU#58KeAOs)5vP%4ZC9dHY=t!(}6KhKpEE5}(kyVJ2c{I`u8 zSMsGo&eb4x1^q7XZ!Zt{1$PPd{Mq=e&OUp)s-)kVU9HwaXMst1k828n@8n>t>xKu9KeY!+mY?qmg7(&H6%48XXEVMGrjLy?bWq(?fcTz)%!Bt z(>*;y&Y~e_#33b96eU@*99y}F4JSx)6C^->L4HDtj39vmjnEf$)i~ez`qVk6s?PVm zPoGdv?V60us-7fcl@YDCD-e>C^$Ql$AfvI_m)0nhN?pfz^R#sxI+75d9wZH_08|U! 
z^UWFt_m|z2w48paWJ-?sDl99b=4eoL+N(#$%Nu0Av1bL+p2q%SOqUmUoL&O+)_&*w z6y7e!q6)3VYGfEcB)>c83!-$nIiK%(UX4qadIkD>0jo8J$LYMGm-vnBEW1?qj@j%~ z!s5|`*P8kwHa0BPxhx3qpjsvs9Dr36J9d>?OHXs_e|zl}G%`355u}!$#2fYMjRKADw!y8?Yf;fuCA%?Xr>qsejixsqSE?Jji49_uLg>mT^;BaqUK3{x&CVkg_wa-*;M*TAnMOQt^(I^;J#$|4K3|>JUyBz3 zCdD!BaxijMn39GcbGR+q7twGNK=EFLkSFQ(hb(^} z<~j+-`y}Qp6nH_DahFy9lRp)Mv}ZoZdrn~ed|n~3n<@OC%Opxi_q|uE$N+H4(eAuc zW43yN8Q@ybsMR1(Fb1f_>nZO%N3GalgH|CowoT{h_=ly#4n>DY!a}s%QClV9-`kvo z3-%4-S3Ke#>Rv%Ios6#f#A-2a#sopnrWdeI$&5D>aq+~u@?)xh-lAg0%89c*s@71~ zY+;7ZeWzHoZ>m16p)>~F{A(rgpYF}d@Ma&}tR>G{6^FOaWNK1uCJ9!KD4^O67tMIR zw2SNIxzo3P_WH6N$u*Hxh2nK5Qvj~y8i{jp0GEjP)xr)ooegM&YS3C_MbMm@SH)_H zn#~qZN$LRW$boE_@r&Wudn>DwHWFWJ-YSg%Sxdodvamt#pnYbr;Y&Ay@1pBUlgW%t z5GrT0u?&Tj&7%&g6cU1ITT~V2nu;(FZPBzR8=4(nZXJb7q|UhZ*eStad9r^vq_toZ zaiG^D8j(oY^NvI-w%*FE6~<+&o;u;ZWsgH0=DP_|74wNRi0%Fix?q#95pk$M7{l-#qu9WK2Lq%Q6uk=Xdh`+YFIrf}Dogeni z`gc6s#NEts=++kBU)54s@2f7k?5!$k^Xt*q$&OUBpIg+eKw%+LUOK|@WLazx4N+a1 z?(?_G0P5RR(k`v8p{vdQux!f-8(*;;ZxzUG3dgH@V2~t3! zD@RghgIru~-;NhEsvvrr`h=sLoar+&H#7{$gI7|vCgcpQ4jMB5gDUHZAIVIh+lKbH ziz*}eeRm0M9O&V8_^?$j1bw07oA0)2C|g|}jEOYZRIgph8H(>CUhf&vCYO0@kM`a( zV(*#jT-|niPV%kVd{2P_Su41-F8wTS+jw0C3gU|Ew|MiF$`TL;XXWwL1rvzPh$j z?M!w{<4H_LUxwGg4SMs!;C+3UJS?GVWT$_e8`>HA9G}_fN{~!)&7+h&C>!UCQhFP# zozIO(5veaKL_p7phno6+>HC|^B6IQ0uQ7zB{*U}(MKPGK&{(_BvsY6u7cPKR&bSn! zRvU^*PH!R|ho?w3nE?wpuof{Bu__CZVwo!&;<$zEuWTY!&q~7+@IuEVnx>Qzxa8!U zWFe%-*L5xyzZV*FIJWbaNe2O{^6Sa~TDn*gH)sc#anVW`K|}Ph zPQ^;X#A=OT*K*lw2emEdgfElgro+g21@QzJwgVt*AgVW*xPJrCiWzZ35Bs(=N;;G{ zP5T?usjkZ`p6(|X*i)ZYuxoWr43BbgRsKDQmwm^iSg>02o=%k283 zHuIu7x2v8j-bD$-tObqI;Xz_826hdhvDn!DC8ILYvUdtt5((d#1~;T5aI!#+?H@PH zI^-1=wpA7@-a|pYA*Oq^p?XSZ7G3nOxXj9?X-N_$1#p>V_!5;msj_{}kT2lxM9uZ- zQC+Oz7W(t1`3h;_==bUciUsn^+_li|kE&RishuX-Z4}%=C*OD>F}1-Iusau(7$23m zOBd%f)1gzwPkH)-9TbA`tBhxNkPirw56!m_PvJ#0E?2h3m8nSGSIiReUO?Xa*z!-A zqqb{PS!Sy}10KTLcd3^oRqZvR?1=Vzr&kzN)r&WkI8`IdY*`Im4nxFlg~1s^*v89Z zBr_aGTY#esPYQ7;Qaf#j%Pk@&;@Z{7&=w7-4v_E^2A*|`>$XN%;lV5Do$k1VWCXA= zZ~KvV!*NgMtlaZk#RF|Ko*-_!4HP&`9_^#G)oxSWP#smU(WYv3dz#jW!>{0bnC_Qh0j~7~&0kn*CyDHZ zy19$Iu(76W(JZhM#(|J5Fm{&{Yav>wW#_DdhYKieT;5=?x|{+8kUfGkp5zW`c^rcm znHXiyW;|<1D-)(XchYExmZcc&v2d2#X!Le%5n2vgp_eW;AoCZ6*8M_XXOY6x*mOM( zP-%B7s6kX+_y=+`mu9j>#yT`P+tYQWru}*e?iD~<_=q8c;?9LVNQijaPe`jKuOihF zRcef4am%2Jmf$pp<#&pAv>j*PHj}#`Sd$2fw)5D!cT`eD&?d#nXKz6UE2ufw^8_8cJ7DSV<<5#sLwBZsuRul4D`Smkwh$ss# z;Q8N4=u1Ob$*veFXzHkZT{46?T#2#P$_6b{GEA zW-$OcaTl%|7?pklRAOaJ%eiQN-AOkRLqTQ6uECN(OKnmDa#K=CPn%H1H~jW!2SA28$Ks&;AFtN$!^G=_+b}{Hu#ReGa-`|?6UXGHN zXC2uynKA@Tibt*-L#n0^ChiD8%D+Q|^83HsL+~I;3p9fLE3}bTj#4ixy4L>9QTM&5 zlfk>Q%RlMQveX4Kx5tEe@cSr6(9M0ZlhQlPsiG|))AMYv>nf$vo~WiTygDV4s{d`d z(Fk(#A(Y+Uvgj(CDU~Qol%;%YL4hG^nKce}zq;Blrst*J8OY9@s)oPDQhl6G+{)UC zrD4^`jj%jR|C$f}7kG7Y$WcErK&gz$3N#EAbZ)e+3VWz1t`rXl5j;*DE(z5R&5Uki`pEJ5x;W}F; z!BRK%88=rr-55zhIvO$oAu*y5#Jf4uhQ@lnBNp1c>$9!^*4 z8(W1$#GPiWtia_@nojx3s7SAbEU$dD%MC+LDz)!9^~9cZJrn0#HHWQ%hh2T3{Q7M7ZeA4gwL4kOF$|7co*3y^HE#au6VTGzD12Y%v4&aJJM-Ld7P z*5CqBP{<e&0vK`0HgyNMY+iuMbD zzxgjTXt#4tiCCk%Z9QvcXkO!>ofrDz-gvgjZDL60f);R+YxdQULYE@emS73eJYt<- zg;a9Vh2q0VE2CzUM4Sh2XOxS$6&GgT-fCi0fd%e$Y$X(1^>?ThkjuOIvp~$o`Uh8P z@vzc9ZOuu|!y<6n2Z(ET$LztsEfXv^!!k6rpm9K8=Pn|&U1-}xf>BA4Fnlh^k-c@d zW_l$D3>#h7N@58v!j@A7ThJ_g?=UB{q^+d&tw^ieIiRL(_Ufn&Xl-{?Y>ToYh9TECwkS%wbA*A^j4WEh4(LhuNxANFlG z(pO8~(L}US%@LkFL0IMGm`Cd} z%(h;}8Zk-VGS?#ex}b9vhPI0)snzt|M74>hv1YZEA8khW@>JJ9;=E@Knf4(YJocE{l!AML1D&DsxWR`_ZY&xl!F z6>?P%?WLo|IwBs2*Li`rGjZ1($@!pQXu+;{;4-w6X$Wy{#5;M0a6LemDrQl8MPc(O zx$e}e2QsJL`qr&K@7A@P&ZLJdPaxWsm6^&UV|oBJ5bv24EIF>HiE>nP@ 
z;@7;20xzt;Vntz|GBryfZRgi%Ullq?I=Q7YxqDGC7bwPyJ|B;H+rMYgBrLXGjPSnky)W^#X)!=q!9`}X5-+mDRECn-N1&)oSZ$6h~@_9H)elouQHvt~pwwIp>!~C%M(~oI>_6ecLzqS8wJPvsL{2xs` z=VklFPqTe6%W`L*96S6lz|)@(@apfx{e5f>-+VmG_mPA@T@qe@kEPf?mZ1Ge%AbZf z|GTklAD_pI6OV=4J~DeAth-Np^(NtCH-*P1{$Txlwgf*5Gx_wy?>+nF$uCX(VKfUQ f@$f@n>$6WyJ~#12&&u;RKsJs-7?i@Fed>P!DeHVi literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/make_qdq_layout_transform_const_folding.py b/onnxruntime/test/testdata/make_qdq_layout_transform_const_folding.py new file mode 100644 index 0000000000000..34b2ff923b18f --- /dev/null +++ b/onnxruntime/test/testdata/make_qdq_layout_transform_const_folding.py @@ -0,0 +1,109 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import numpy as np +import onnx + +import onnxruntime +from onnxruntime.quantization import CalibrationDataReader, QuantFormat, QuantType, quantize_static +from onnxruntime.quantization.shape_inference import quant_pre_process + + +class DataReader(CalibrationDataReader): + def __init__(self, model_path: str): + self.enum_data = None + + # Use inference session to get input shape. + session = onnxruntime.InferenceSession(model_path, providers=["CPUExecutionProvider"]) + + inputs = session.get_inputs() + + self.data_list = [] + + # Generate 10 random float32 inputs + for _ in range(10): + input_data = {inp.name: np.random.random(inp.shape).astype(np.float32) for inp in inputs} + self.data_list.append(input_data) + + self.datasize = len(self.data_list) + + def get_next(self): + if self.enum_data is None: + self.enum_data = iter(self.data_list) + return next(self.enum_data, None) + + def rewind(self): + self.enum_data = None + + +if __name__ == "__main__": + """ + Creates a QDQ model with a shared initializer. + The transpose optimizer will generate a (weight -> Transpose -> Squeeze) sequence that can be constant folded + by the tranpose optimizer itself. + """ + shape = (1, 3, 3, 3) + + input0 = onnx.helper.make_tensor_value_info("input0", onnx.TensorProto.FLOAT, shape) + input1 = onnx.helper.make_tensor_value_info("input1", onnx.TensorProto.FLOAT, shape) + output0 = onnx.helper.make_tensor_value_info("output0", onnx.TensorProto.FLOAT, None) + output1 = onnx.helper.make_tensor_value_info("output1", onnx.TensorProto.FLOAT, None) + + # Shared weight (will be unsqueezed and fed into a Squeeze op by layout transformation). 
+ const_1_weight = onnx.numpy_helper.from_array(np.array([1.0] * shape[-1], dtype=np.float32), "const_1_weight") + + # Transpose with channel-first perm + transpose_node = onnx.helper.make_node( + "Transpose", ["input0"], ["transpose_out"], name="transpose_node", perm=(0, 3, 1, 2) + ) + + # Mul0 + mul0_node = onnx.helper.make_node("Mul", ["transpose_out", "const_1_weight"], ["mul0_out"], name="mul0_node") + + # Mul1 + mul1_node = onnx.helper.make_node("Mul", ["input1", "const_1_weight"], ["output1"], name="mul1_node") + + # Conv0 + conv_w_shape = (1, 3, 2, 2) + conv_weight_data = np.random.normal(-1.0, 1.0, conv_w_shape).astype(np.float32) + conv_weight = onnx.numpy_helper.from_array(conv_weight_data, "conv_weight") + conv_node = onnx.helper.make_node("Conv", ["mul0_out", "conv_weight"], ["output0"], name="conv_node") + + graph = onnx.helper.make_graph( + [transpose_node, mul0_node, mul1_node, conv_node], + "layout_transform_const_folding", + [input0, input1], + [output0, output1], + initializer=[const_1_weight, conv_weight], + ) + opset_imports = [ + onnx.helper.make_opsetid("", 19), + ] + f32_model = onnx.helper.make_model(graph, opset_imports=opset_imports) + + print("[INFO]: Running onnx.checker on f32 model") + f32_model = onnx.shape_inference.infer_shapes(f32_model) + onnx.checker.check_model(f32_model, True) + f32_model_path = "layout_transform_const_folding.f32.onnx" + + print(f"[INFO]: Saving {f32_model_path}") + onnx.save_model(f32_model, f32_model_path) + + # Quantize model + qdq_model_path = "layout_transform_const_folding.qdq.onnx" + print("[INFO]: Creating QDQ model") + quantize_static( + f32_model_path, + qdq_model_path, + DataReader(f32_model_path), + quant_format=QuantFormat.QDQ, + activation_type=QuantType.QUInt8, + weight_type=QuantType.QUInt8, + op_types_to_quantize=[node.op_type for node in f32_model.graph.node], + extra_options={"DedicatedQDQPair": True, "ForceQuantizeNoInputCheck": True}, + ) + quant_pre_process(qdq_model_path, qdq_model_path) + qdq_model = onnx.load_model(qdq_model_path) + onnx.checker.check_model(qdq_model, True) + onnx.save_model(qdq_model, qdq_model_path) + print(f"[INFO]: Created QDQ model {qdq_model_path}")
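
For reference, a minimal sketch of how to sanity-check the generated test model outside the test suite. This is illustrative only and not part of the patch; it assumes the script above has been run so that layout_transform_const_folding.qdq.onnx exists in the current directory:

    import onnx

    model = onnx.load_model("layout_transform_const_folding.qdq.onnx")

    # Count ops to confirm the QDQ structure the new tests rely on
    # (Mul/Conv/Transpose wrapped in QuantizeLinear/DequantizeLinear pairs).
    op_counts = {}
    for node in model.graph.node:
        op_counts[node.op_type] = op_counts.get(node.op_type, 0) + 1
    print(op_counts)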