diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h index 8f847fe66aa73..c48aea32507fb 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h @@ -42,6 +42,8 @@ class TreeEnsembleCommon : public TreeEnsembleCommonAttributes { protected: std::vector base_values_; std::vector> nodes_; + std::vector tree_size; + mutable std::vector compared; // Type of weights should be a vector of OutputType. Onnx specifications says it must be float. // Lightgbm requires a double to do the summation of all trees predictions. That's why // `ThresholdType` is used as well for output type (double as well for lightgbm) and not `OutputType`. @@ -80,6 +82,8 @@ class TreeEnsembleCommon : public TreeEnsembleCommonAttributes { const std::vector& target_class_weights_as_tensor); protected: + TreeNodeElement* findAnswer(TreeNodeElement* root, + const InputType* x_data) const; TreeNodeElement* ProcessTreeNodeLeave(TreeNodeElement* root, const InputType* x_data) const; @@ -92,6 +96,9 @@ class TreeEnsembleCommon : public TreeEnsembleCommonAttributes { const std::vector& nodes_values_as_tensor, const std::vector& node_values, const std::vector& nodes_missing_value_tracks_true, std::vector& updated_mapping, int64_t tree_id, const InlinedVector& node_tree_ids); + size_t calcTreesSize(const size_t i, const InlinedVector& cmodes, const InlinedVector& truenode_ids, + const InlinedVector& falsenode_ids); + }; template @@ -216,6 +223,8 @@ Status TreeEnsembleCommon::Init( node_tree_ids.reserve(limit); nodes_.clear(); nodes_.reserve(limit); + tree_size.reserve(limit); + compared.reserve(limit); roots_.clear(); std::unordered_map node_tree_ids_map; node_tree_ids_map.reserve(limit); @@ -389,9 +398,11 @@ size_t TreeEnsembleCommon::AddNodes( // We don't need to store the false branch pointer since we know it is always in the immediate next entry in nodes_. // nodes_[node_pos].falsenode_inc_or_n_weights.ptr = &nodes_[false_branch]; nodes_[node_pos].truenode_or_weight.ptr = &nodes_[true_branch]; + tree_size[node_pos] = 1 + tree_size[node_pos + 1] + tree_size[true_branch]; } else { nodes_[node_pos].truenode_or_weight.weight_data.weight = 0; nodes_[node_pos].truenode_or_weight.weight_data.n_weights = 0; + tree_size[node_pos] = 1; } return node_pos; } @@ -689,6 +700,27 @@ inline bool _isnan_(double x) { return std::isnan(x); } inline bool _isnan_(int64_t) { return false; } inline bool _isnan_(int32_t) { return false; } +template +TreeNodeElement* +TreeEnsembleCommon::findAnswer( + TreeNodeElement* root, const InputType* x_data) const { + // std::cout<<"brbrb "<feature_id]; + compared[i + rootIdx] = val <= temp->value_or_unique_weight || (root->is_missing_track_true() && _isnan_(val)); + } + + while (root->is_not_leaf()) { + root = compared[root - &nodes_[0]] ? root->truenode_or_weight.ptr : root + 1; + } + + return root; +} + template TreeNodeElement* TreeEnsembleCommon::ProcessTreeNodeLeave( @@ -699,6 +731,11 @@ TreeEnsembleCommon::ProcessTreeNodeLeave( case NODE_MODE::BRANCH_LEQ: if (has_missing_tracks_) { while (root->is_not_leaf()) { + size_t idx = root - &nodes_[0]; + if (tree_size[idx] <= 5) { //make constant + return findAnswer(root, x_data); + } + val = x_data[root->feature_id]; root = (val <= root->value_or_unique_weight || (root->is_missing_track_true() && _isnan_(val))) ? root->truenode_or_weight.ptr @@ -706,6 +743,11 @@ TreeEnsembleCommon::ProcessTreeNodeLeave( } } else { while (root->is_not_leaf()) { + size_t idx = root - &nodes_[0]; + if (tree_size[idx] <= 16) { //make constant + return findAnswer(root, x_data); + } + val = x_data[root->feature_id]; root = val <= root->value_or_unique_weight ? root->truenode_or_weight.ptr : root + 1; }