Skip to content

Commit

Permalink
Sort roots
Browse files Browse the repository at this point in the history
  • Loading branch information
bili2002 committed Jan 18, 2024
1 parent 19a772c commit a9e3d41
Showing 1 changed file with 24 additions and 1 deletion.
25 changes: 24 additions & 1 deletion onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
// Let's construct nodes_ such that the false branch is always the next element in nodes_.
// updated_mapping will translates the old position of each node to the new node position in nodes_.
std::vector<size_t> updated_mapping(nodes_treeids.size(), 0);
std::vector<std::pair<size_t, size_t>> roots_index;
int64_t previous_tree_id = -1;
for (i = 0; i < n_nodes_; ++i) {
if (previous_tree_id == -1 || (previous_tree_id != node_tree_ids[i].tree_id)) {
Expand All @@ -281,11 +282,33 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
size_t root_position =
AddNodes(i, cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor, nodes_values,
nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids);
roots_.push_back(&nodes_[root_position]);
roots_index.push_back({i, root_position});
previous_tree_id = tree_id;
}
}

sort(roots_index.begin(), roots_index.end(), [&](
std::pair<size_t, size_t> left, std::pair<size_t, size_t> right){
if (nodes_[left.second].feature_id == nodes_[right.second].feature_id) {
return nodes_[left.second].value_or_unique_weight < nodes_[right.second].value_or_unique_weight;
}
return nodes_[left.second].feature_id < nodes_[right.second].feature_id;
});

updated_mapping.fill(0);
nodes_.clear();
nodes_.reserve(limit);

previous_tree_id = -1;
for (auto& curr : roots_index) {
int64_t tree_id = node_tree_ids[curr.first].tree_id;
size_t root_position =
AddNodes(curr.first, cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor, nodes_values,
nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids);
roots_.push_back(&nodes_[root_position]);
previous_tree_id = tree_id;
}

n_trees_ = roots_.size();
if (((int64_t)nodes_.size()) != n_nodes_) {
ORT_THROW("Number of nodes in nodes_ (", nodes_.size(), ") is different from n_nodes (", n_nodes_, ").");
Expand Down

0 comments on commit a9e3d41

Please sign in to comment.