Skip to content

Commit

Permalink
Fix bug with leaves folding
Browse files Browse the repository at this point in the history
  • Loading branch information
bili2002 committed May 30, 2024
1 parent 0dac1ef commit 091df95
Showing 1 changed file with 71 additions and 47 deletions.
118 changes: 71 additions & 47 deletions onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#include "core/platform/threadpool.h"
#include "tree_ensemble_helper.h"

#include <algorithm>
#include <cstring>

namespace onnxruntime {
namespace ml {
namespace detail {
Expand Down Expand Up @@ -87,15 +89,17 @@ class TreeEnsembleCommon : public TreeEnsembleCommonAttributes {
void ComputeAgg(concurrency::ThreadPool* ttp, const Tensor* X, Tensor* Y, Tensor* label, const AGG& agg) const;

private:
bool CheckIfSubtreesAreEqual(const size_t left_id, const size_t right_id, const InlinedVector<NODE_MODE>& cmodes,
bool CheckIfSubtreesAreEqual(const size_t left_id, const size_t right_id, const int64_t tree_id, const InlinedVector<NODE_MODE>& cmodes,
const InlinedVector<size_t>& truenode_ids, const InlinedVector<size_t>& falsenode_ids, const std::vector<int64_t>& nodes_featureids,
const std::vector<ThresholdType>& nodes_values_as_tensor, const std::vector<float>& node_values,
const std::vector<int64_t>& nodes_missing_value_tracks_true);
const std::vector<float>& target_class_weights, const std::vector<ThresholdType>& target_class_weights_as_tensor,
const InlinedVector<TreeNodeElementId>& node_tree_ids, InlinedVector<std::pair<TreeNodeElementId, uint32_t>> indices);
size_t AddNodes(const size_t i, const InlinedVector<NODE_MODE>& cmodes, const InlinedVector<size_t>& truenode_ids,
const InlinedVector<size_t>& falsenode_ids, const std::vector<int64_t>& nodes_featureids,
const std::vector<ThresholdType>& nodes_values_as_tensor, const std::vector<float>& node_values,
const std::vector<int64_t>& nodes_missing_value_tracks_true, std::vector<size_t>& updated_mapping,
int64_t tree_id, const InlinedVector<TreeNodeElementId>& node_tree_ids);
int64_t tree_id, const InlinedVector<TreeNodeElementId>& node_tree_ids, const std::vector<float>& target_class_weights,
const std::vector<ThresholdType>& target_class_weights_as_tensor, InlinedVector<std::pair<TreeNodeElementId, uint32_t>> indices);
};

template <typename InputType, typename ThresholdType, typename OutputType>
Expand Down Expand Up @@ -274,6 +278,16 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
}
}

// Sort targets
InlinedVector<std::pair<TreeNodeElementId, uint32_t>> indices;
indices.reserve(target_class_nodeids.size());
for (i = 0, limit = target_class_nodeids.size(); i < limit; i++) {
indices.emplace_back(
std::pair<TreeNodeElementId, uint32_t>(TreeNodeElementId{target_class_treeids[i], target_class_nodeids[i]}, i));
}

std::sort(indices.begin(), indices.end());

// Let's construct nodes_ such that the false branch is always the next element in nodes_.
// updated_mapping translates the old position of each node to its new position in nodes_.
std::vector<size_t> updated_mapping(nodes_treeids.size(), 0);
Expand All @@ -284,23 +298,14 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
int64_t tree_id = node_tree_ids[i].tree_id;
size_t root_position =
AddNodes(i, cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor, nodes_values,
nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids);
nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids,
target_class_weights, target_class_weights_as_tensor, indices);
roots_.push_back(&nodes_[root_position]);
previous_tree_id = tree_id;
}
}
n_trees_ = roots_.size();

// Sort targets
InlinedVector<std::pair<TreeNodeElementId, uint32_t>> indices;
indices.reserve(target_class_nodeids.size());
for (i = 0, limit = target_class_nodeids.size(); i < limit; i++) {
indices.emplace_back(
std::pair<TreeNodeElementId, uint32_t>(TreeNodeElementId{target_class_treeids[i], target_class_nodeids[i]}, i));
}

std::sort(indices.begin(), indices.end());

TreeNodeElementId ind;
SparseValue<ThresholdType> w;
size_t indi;
Expand Down Expand Up @@ -343,24 +348,38 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(

template <typename InputType, typename ThresholdType, typename OutputType>
bool TreeEnsembleCommon<InputType, ThresholdType, OutputType>::CheckIfSubtreesAreEqual(
const size_t left_id, const size_t right_id, const InlinedVector<NODE_MODE>& cmodes, const InlinedVector<size_t>& truenode_ids,
const InlinedVector<size_t>& falsenode_ids, const std::vector<int64_t>& nodes_featureids,
const size_t left_id, const size_t right_id, const int64_t tree_id, const InlinedVector<NODE_MODE>& cmodes,
const InlinedVector<size_t>& truenode_ids, const InlinedVector<size_t>& falsenode_ids, const std::vector<int64_t>& nodes_featureids,
const std::vector<ThresholdType>& nodes_values_as_tensor, const std::vector<float>& node_values,
const std::vector<int64_t>& nodes_missing_value_tracks_true) {
const std::vector<float>& target_class_weights, const std::vector<ThresholdType>& target_class_weights_as_tensor,
const InlinedVector<TreeNodeElementId>& node_tree_ids, InlinedVector<std::pair<TreeNodeElementId, uint32_t>> indices) {
// Leaves have values set at 0
if (cmodes[left_id] != cmodes[right_id]
|| nodes_featureids[left_id] != nodes_featureids[right_id]
|| (!nodes_values_as_tensor.empty() && nodes_values_as_tensor[left_id] != nodes_values_as_tensor[right_id])
|| (nodes_values_as_tensor.empty() && static_cast<ThresholdType>(node_values[left_id]) != static_cast<ThresholdType>(node_values[right_id]))
|| (!nodes_missing_value_tracks_true.empty() && nodes_missing_value_tracks_true[left_id] != nodes_missing_value_tracks_true[right_id])) {
|| (nodes_values_as_tensor.empty() && node_values[left_id] != node_values[right_id])) {
return false;
}

if (cmodes[left_id] == NODE_MODE::LEAF) {
return true;
auto left_tree_node = node_tree_ids[left_id];
auto left_target_node = std::lower_bound(indices.begin(), indices.end(), std::make_pair(left_tree_node, uint32_t(0)))->second;

auto right_tree_node = node_tree_ids[right_id];
auto right_target_node = std::lower_bound(indices.begin(), indices.end(), std::make_pair(right_tree_node, uint32_t(0)))->second;

if (target_class_weights_as_tensor.empty()) {
return target_class_weights[left_target_node] == target_class_weights[right_target_node];
}
else {
return target_class_weights_as_tensor[left_target_node] == target_class_weights_as_tensor[right_target_node];
}
}

return CheckIfSubtreesAreEqual(falsenode_ids[left_id], falsenode_ids[right_id], cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor, node_values, nodes_missing_value_tracks_true)
&& CheckIfSubtreesAreEqual(truenode_ids[left_id], truenode_ids[right_id], cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor, node_values, nodes_missing_value_tracks_true);
return CheckIfSubtreesAreEqual(falsenode_ids[left_id], falsenode_ids[right_id], tree_id, cmodes, truenode_ids, falsenode_ids, nodes_featureids,
nodes_values_as_tensor, node_values, target_class_weights, target_class_weights_as_tensor, node_tree_ids, indices)
&& CheckIfSubtreesAreEqual(truenode_ids[left_id], truenode_ids[right_id], tree_id, cmodes, truenode_ids, falsenode_ids, nodes_featureids,
nodes_values_as_tensor, node_values, target_class_weights, target_class_weights_as_tensor, node_tree_ids, indices);
}

inline void UpdateThreshold(double val, double& mask) {
Expand All @@ -373,13 +392,17 @@ inline void UpdateThreshold(float val, float& mask) {
mask = *reinterpret_cast<float*>(&new_mask);
}

// Number of bits in the storage of type T, as a signed 64-bit value so that it
// compares cleanly against signed or floating-point thresholds.
#define BITCOUNT(T) (static_cast<int64_t>(sizeof(T) * 8))
// True when v is a valid 1-based bit position for a bit mask stored in a T,
// i.e. 1 <= v <= BITCOUNT(T).
// The argument is parenthesized so that low-precedence expressions (e.g. `c ? a : b`)
// are compared as a whole. NOTE: v is evaluated twice; do not pass an expression
// with side effects.
#define CANMASK(v, T) ((v) >= 1 && (v) <= BITCOUNT(T))

template <typename InputType, typename ThresholdType, typename OutputType>
size_t TreeEnsembleCommon<InputType, ThresholdType, OutputType>::AddNodes(
const size_t i, const InlinedVector<NODE_MODE>& cmodes, const InlinedVector<size_t>& truenode_ids,
const InlinedVector<size_t>& falsenode_ids, const std::vector<int64_t>& nodes_featureids,
const std::vector<ThresholdType>& nodes_values_as_tensor, const std::vector<float>& node_values,
const std::vector<int64_t>& nodes_missing_value_tracks_true, std::vector<size_t>& updated_mapping, int64_t tree_id,
const InlinedVector<TreeNodeElementId>& node_tree_ids) {
const InlinedVector<TreeNodeElementId>& node_tree_ids, const std::vector<float>& target_class_weights,
const std::vector<ThresholdType>& target_class_weights_as_tensor, InlinedVector<std::pair<TreeNodeElementId, uint32_t>> indices) {
// Validate this index maps to the same tree_id as the one we should be building.
if (node_tree_ids[i].tree_id != tree_id) {
ORT_THROW("Tree id mismatch. Expected ", tree_id, " but got ", node_tree_ids[i].tree_id, " at position ", i);
Expand All @@ -403,18 +426,13 @@ size_t TreeEnsembleCommon<InputType, ThresholdType, OutputType>::AddNodes(
}

node.value_or_unique_weight = 0;
if (node.flags == NODE_MODE::BRANCH_EQ && (nodes_values_as_tensor.empty() ?
static_cast<ThresholdType>(node_values[i]) > 0 && static_cast<ThresholdType>(node_values[i]) <= sizeof(ThresholdType) * 8 :
nodes_values_as_tensor[i] > 0 && nodes_values_as_tensor[i] <= sizeof(ThresholdType) * 8) ) {
const auto node_threshold = nodes_values_as_tensor.empty() ? static_cast<ThresholdType>(node_values[i]) : nodes_values_as_tensor[i];
if (node.flags == NODE_MODE::BRANCH_EQ && CANMASK(node_threshold, ThresholdType)) {
UpdateThreshold(node_threshold, node.value_or_unique_weight);
node.flags = NODE_MODE::BRANCH_SM;
UpdateThreshold(
nodes_values_as_tensor.empty() ? static_cast<ThresholdType>(node_values[i]) : nodes_values_as_tensor[i],
node.value_or_unique_weight
);
}
else {
node.value_or_unique_weight =
nodes_values_as_tensor.empty() ? static_cast<ThresholdType>(node_values[i]) : nodes_values_as_tensor[i];
node.value_or_unique_weight = node_threshold;
}

if (i < static_cast<size_t>(nodes_missing_value_tracks_true.size()) && nodes_missing_value_tracks_true[i] == 1) {
Expand All @@ -424,29 +442,31 @@ size_t TreeEnsembleCommon<InputType, ThresholdType, OutputType>::AddNodes(
if (nodes_[node_pos].is_not_leaf()) {
auto falsenode_id = falsenode_ids[i];
if (nodes_[node_pos].flags == NODE_MODE::BRANCH_SM) {
auto falsenode_threshold = nodes_values_as_tensor.empty() ? static_cast<ThresholdType>(node_values[falsenode_id]) : nodes_values_as_tensor[falsenode_id];

while (cmodes[falsenode_id] == NODE_MODE::BRANCH_EQ && nodes_[node_pos].feature_id == nodes_featureids[falsenode_id] &&
(nodes_values_as_tensor.empty() ?
static_cast<ThresholdType>(node_values[falsenode_id]) > 0 && static_cast<ThresholdType>(node_values[falsenode_id]) <= sizeof(ThresholdType) * 8 :
nodes_values_as_tensor[falsenode_id] > 0 && nodes_values_as_tensor[falsenode_id] <= sizeof(ThresholdType) * 8) &&
CheckIfSubtreesAreEqual(truenode_ids[i], truenode_ids[falsenode_id], cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor, node_values, nodes_missing_value_tracks_true)) {
UpdateThreshold(
nodes_values_as_tensor.empty() ? static_cast<ThresholdType>(node_values[falsenode_id]) : nodes_values_as_tensor[falsenode_id],
nodes_[node_pos].value_or_unique_weight
);
CANMASK(falsenode_threshold, ThresholdType) &&
CheckIfSubtreesAreEqual(truenode_ids[i], truenode_ids[falsenode_id], tree_id, cmodes, truenode_ids, falsenode_ids,
nodes_featureids, nodes_values_as_tensor, node_values, target_class_weights, target_class_weights_as_tensor, node_tree_ids, indices)) {

UpdateThreshold(falsenode_threshold, nodes_[node_pos].value_or_unique_weight);
falsenode_id = falsenode_ids[falsenode_id];
falsenode_threshold = nodes_values_as_tensor.empty() ? static_cast<ThresholdType>(node_values[falsenode_id]) : nodes_values_as_tensor[falsenode_id];
}
}

size_t false_branch =
AddNodes(falsenode_id, cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor,
node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids);
node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids,
target_class_weights, target_class_weights_as_tensor, indices);
if (false_branch != node_pos + 1) {
ORT_THROW("False node must always be the next node, but it isn't at index ", node_pos, " with flags ",
static_cast<int>(nodes_[node_pos].flags));
}
size_t true_branch =
AddNodes(truenode_ids[i], cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor,
node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids);
node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids,
target_class_weights, target_class_weights_as_tensor, indices);
// We don't need to store the false branch pointer since we know it is always in the immediate next entry in nodes_.
// nodes_[node_pos].falsenode_inc_or_n_weights.ptr = &nodes_[false_branch];
nodes_[node_pos].truenode_or_weight.ptr = &nodes_[true_branch];
Expand Down Expand Up @@ -746,11 +766,15 @@ void TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ComputeAgg(concur
}

// Tests whether the category `val` belongs to the set encoded in `mask`.
//
// `mask` is not interpreted as a number: its 64 storage bits are read as a bit set,
// and the function reports whether bit (k - 1) is set, where k is `val` truncated
// toward zero.
//
// Returns false for every `val` that does not truncate to a position in [1, 64],
// including negative values, NaN and infinities.
inline bool SetMembershipCheck(double val, double mask) {
  constexpr int64_t kMaskBits = static_cast<int64_t>(sizeof(double) * 8);
  // Range-check in floating point BEFORE converting: casting NaN or an out-of-range
  // value to an integer is undefined behaviour. NaN fails both comparisons, and
  // `val < kMaskBits + 1` accepts exactly the values that truncate to 1..kMaskBits.
  if (!(val >= 1.0 && val < static_cast<double>(kMaskBits + 1))) {
    return false;
  }
  // Copy the bit pattern out with memcpy; reading a double through a uint64_t*
  // would violate strict aliasing.
  uint64_t mask_bits = 0;
  std::memcpy(&mask_bits, &mask, sizeof(mask_bits));
  const auto bit_index = static_cast<int64_t>(val) - 1;  // 0..63 after the range check
  // Unsigned shift so that selecting the top bit never touches a sign bit.
  return ((uint64_t{1} << bit_index) & mask_bits) != 0;
}

// Tests whether the category `val` belongs to the set encoded in `mask`.
//
// `mask` is not interpreted as a number: its 32 storage bits are read as a bit set,
// and the function reports whether bit (k - 1) is set, where k is `val` truncated
// toward zero.
//
// Returns false for every `val` that does not truncate to a position in [1, 32],
// including negative values, NaN and infinities.
inline bool SetMembershipCheck(float val, float mask) {
  constexpr int64_t kMaskBits = static_cast<int64_t>(sizeof(float) * 8);
  // Range-check in floating point BEFORE converting: casting NaN or an out-of-range
  // value to an integer is undefined behaviour. NaN fails both comparisons, and
  // `val < kMaskBits + 1` accepts exactly the values that truncate to 1..kMaskBits.
  if (!(val >= 1.0f && val < static_cast<float>(kMaskBits + 1))) {
    return false;
  }
  // Copy the bit pattern out with memcpy; reading a float through a uint32_t*
  // would violate strict aliasing.
  uint32_t mask_bits = 0;
  std::memcpy(&mask_bits, &mask, sizeof(mask_bits));
  const auto bit_index = static_cast<int64_t>(val) - 1;  // 0..31 after the range check
  // Unsigned shift so that selecting the top bit never touches a sign bit.
  return ((uint32_t{1} << bit_index) & mask_bits) != 0;
}

// Reports whether the single-precision value x is NaN; thin wrapper over std::isnan.
inline bool _isnan_(float x) {
  const bool is_nan = std::isnan(x);
  return is_nan;
}
Expand Down Expand Up @@ -799,14 +823,14 @@ TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ProcessTreeNodeLeave(
if (has_missing_tracks_) {
while (root->is_not_leaf()) {
val = x_data[root->feature_id];
root = ((val >= 1 && SetMembershipCheck(val, root->value_or_unique_weight)) || (root->is_missing_track_true() && _isnan_(val)))
root = (SetMembershipCheck(val, root->value_or_unique_weight) || (root->is_missing_track_true() && _isnan_(val)))
? root->truenode_or_weight.ptr
: root + 1;
}
} else {
while (root->is_not_leaf()) {
val = x_data[root->feature_id];
root = (val >= 1 && SetMembershipCheck(val, root->value_or_unique_weight)) ? root->truenode_or_weight.ptr : root + 1;
root = SetMembershipCheck(val, root->value_or_unique_weight) ? root->truenode_or_weight.ptr : root + 1;
}
}
case NODE_MODE::LEAF:
Expand Down Expand Up @@ -843,7 +867,7 @@ TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ProcessTreeNodeLeave(
: root + 1;
break;
case NODE_MODE::BRANCH_SM:
root = ((val >= 1 && SetMembershipCheck(val, root->value_or_unique_weight)) || (root->is_missing_track_true() && _isnan_(val)))
root = (SetMembershipCheck(val, root->value_or_unique_weight) || (root->is_missing_track_true() && _isnan_(val)))
? root->truenode_or_weight.ptr
: root + 1;
break;
Expand Down

0 comments on commit 091df95

Please sign in to comment.