Skip to content

Commit

Permalink
New calculation for small trees
Browse files Browse the repository at this point in the history
  • Loading branch information
bili2002 committed Jan 18, 2024
1 parent 19a772c commit b12805d
Showing 1 changed file with 42 additions and 0 deletions.
42 changes: 42 additions & 0 deletions onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class TreeEnsembleCommon : public TreeEnsembleCommonAttributes {
protected:
std::vector<ThresholdType> base_values_;
std::vector<TreeNodeElement<ThresholdType>> nodes_;
std::vector<size_t> tree_size;
mutable std::vector<bool> compared;
// Type of weights should be a vector of OutputType. Onnx specifications says it must be float.
// Lightgbm requires a double to do the summation of all trees predictions. That's why
// `ThresholdType` is used as well for output type (double as well for lightgbm) and not `OutputType`.
Expand Down Expand Up @@ -80,6 +82,8 @@ class TreeEnsembleCommon : public TreeEnsembleCommonAttributes {
const std::vector<ThresholdType>& target_class_weights_as_tensor);

protected:
TreeNodeElement<ThresholdType>* findAnswer(TreeNodeElement<ThresholdType>* root,
const InputType* x_data) const;
TreeNodeElement<ThresholdType>* ProcessTreeNodeLeave(TreeNodeElement<ThresholdType>* root,
const InputType* x_data) const;

Expand All @@ -92,6 +96,9 @@ class TreeEnsembleCommon : public TreeEnsembleCommonAttributes {
const std::vector<ThresholdType>& nodes_values_as_tensor, const std::vector<float>& node_values,
const std::vector<int64_t>& nodes_missing_value_tracks_true, std::vector<size_t>& updated_mapping,
int64_t tree_id, const InlinedVector<TreeNodeElementId>& node_tree_ids);
size_t calcTreesSize(const size_t i, const InlinedVector<NODE_MODE>& cmodes, const InlinedVector<size_t>& truenode_ids,
const InlinedVector<size_t>& falsenode_ids);

};

template <typename InputType, typename ThresholdType, typename OutputType>
Expand Down Expand Up @@ -216,6 +223,8 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
node_tree_ids.reserve(limit);
nodes_.clear();
nodes_.reserve(limit);
tree_size.reserve(limit);
compared.reserve(limit);
roots_.clear();
std::unordered_map<TreeNodeElementId, size_t, TreeNodeElementId::hash_fn> node_tree_ids_map;
node_tree_ids_map.reserve(limit);
Expand Down Expand Up @@ -389,9 +398,11 @@ size_t TreeEnsembleCommon<InputType, ThresholdType, OutputType>::AddNodes(
// We don't need to store the false branch pointer since we know it is always in the immediate next entry in nodes_.
// nodes_[node_pos].falsenode_inc_or_n_weights.ptr = &nodes_[false_branch];
nodes_[node_pos].truenode_or_weight.ptr = &nodes_[true_branch];
tree_size[node_pos] = 1 + tree_size[node_pos + 1] + tree_size[true_branch];
} else {
nodes_[node_pos].truenode_or_weight.weight_data.weight = 0;
nodes_[node_pos].truenode_or_weight.weight_data.n_weights = 0;
tree_size[node_pos] = 1;
}
return node_pos;
}
Expand Down Expand Up @@ -689,6 +700,27 @@ inline bool _isnan_(double x) { return std::isnan(x); }
inline bool _isnan_(int64_t) { return false; }
inline bool _isnan_(int32_t) { return false; }

template <typename InputType, typename ThresholdType, typename OutputType>
TreeNodeElement<ThresholdType>*
TreeEnsembleCommon<InputType, ThresholdType, OutputType>::findAnswer(
TreeNodeElement<ThresholdType>* root, const InputType* x_data) const {
// std::cout<<"brbrb "<<std::endl;
size_t rootIdx = root - &nodes_[0];
size_t n = (size_t)tree_size[rootIdx];

auto temp = root;
for (size_t i = 0; i < n; i++, temp++) {
InputType val = x_data[temp->feature_id];
compared[i + rootIdx] = val <= temp->value_or_unique_weight || (root->is_missing_track_true() && _isnan_(val));
}

while (root->is_not_leaf()) {
root = compared[root - &nodes_[0]] ? root->truenode_or_weight.ptr : root + 1;
}

return root;
}

template <typename InputType, typename ThresholdType, typename OutputType>
TreeNodeElement<ThresholdType>*
TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ProcessTreeNodeLeave(
Expand All @@ -699,13 +731,23 @@ TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ProcessTreeNodeLeave(
case NODE_MODE::BRANCH_LEQ:
if (has_missing_tracks_) {
while (root->is_not_leaf()) {
size_t idx = root - &nodes_[0];
if (tree_size[idx] <= 5) { //make constant
return findAnswer(root, x_data);
}

val = x_data[root->feature_id];
root = (val <= root->value_or_unique_weight || (root->is_missing_track_true() && _isnan_(val)))
? root->truenode_or_weight.ptr
: root + 1;
}
} else {
while (root->is_not_leaf()) {
size_t idx = root - &nodes_[0];
if (tree_size[idx] <= 16) { //make constant
return findAnswer(root, x_data);
}

val = x_data[root->feature_id];
root = val <= root->value_or_unique_weight ? root->truenode_or_weight.ptr : root + 1;
}
Expand Down

0 comments on commit b12805d

Please sign in to comment.