From a9e3d41e7a935c00551761e14182108cea303dd9 Mon Sep 17 00:00:00 2001 From: Bilyana Indzheva Date: Thu, 18 Jan 2024 23:44:43 +0200 Subject: [PATCH] Sort roots --- .../providers/cpu/ml/tree_ensemble_common.h | 25 ++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h index 8f847fe66aa73..986b16eb0c841 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h @@ -273,6 +273,7 @@ Status TreeEnsembleCommon::Init( // Let's construct nodes_ such that the false branch is always the next element in nodes_. // updated_mapping will translates the old position of each node to the new node position in nodes_. std::vector updated_mapping(nodes_treeids.size(), 0); + std::vector> roots_index; int64_t previous_tree_id = -1; for (i = 0; i < n_nodes_; ++i) { if (previous_tree_id == -1 || (previous_tree_id != node_tree_ids[i].tree_id)) { @@ -281,11 +282,33 @@ Status TreeEnsembleCommon::Init( size_t root_position = AddNodes(i, cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor, nodes_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids); - roots_.push_back(&nodes_[root_position]); + roots_index.push_back({i, root_position}); previous_tree_id = tree_id; } } + sort(roots_index.begin(), roots_index.end(), [&]( + std::pair left, std::pair right){ + if (nodes_[left.second].feature_id == nodes_[right.second].feature_id) { + return nodes_[left.second].value_or_unique_weight < nodes_[right.second].value_or_unique_weight; + } + return nodes_[left.second].feature_id < nodes_[right.second].feature_id; + }); + + updated_mapping.fill(0); + nodes_.clear(); + nodes_.reserve(limit); + + previous_tree_id = -1; + for (auto& curr : roots_index) { + int64_t tree_id = node_tree_ids[curr.first].tree_id; + size_t root_position = + AddNodes(curr.first, cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor, nodes_values, + nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids); + roots_.push_back(&nodes_[root_position]); + previous_tree_id = tree_id; + } + n_trees_ = roots_.size(); if (((int64_t)nodes_.size()) != n_nodes_) { ORT_THROW("Number of nodes in nodes_ (", nodes_.size(), ") is different from n_nodes (", n_nodes_, ").");