Skip to content

Commit

Permalink
Fixed for same_mode_=true
Browse files Browse the repository at this point in the history
  • Loading branch information
bili2002 committed Dec 19, 2023
1 parent 29a96da commit 2e79d4f
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 17 deletions.
21 changes: 5 additions & 16 deletions onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -394,9 +394,10 @@ size_t TreeEnsembleCommon<InputType, ThresholdType, OutputType>::AddNodes(
TreeNodeElement<ThresholdType> node;
auto falsenode_id = falsenode_ids[i];
auto truenode_id = truenode_ids[i];
if (tree_size[falsenode_id] < tree_size[truenode_id]) {
std::swap(falsenode_id, truenode_id );
if (tree_size[falsenode_id] < tree_size[truenode_id] && !same_mode_) {
std::swap(falsenode_id, truenode_id);
node.flags = static_cast<uint8_t>(ReverseNodeMode(cmodes[i]));
node.flags |= static_cast<uint8_t>(MissingTrack::kTrue);
}
else {
node.flags = static_cast<uint8_t>(cmodes[i]);
Expand All @@ -408,7 +409,7 @@ size_t TreeEnsembleCommon<InputType, ThresholdType, OutputType>::AddNodes(
node.value_or_unique_weight =
nodes_values_as_tensor.empty() ? static_cast<ThresholdType>(node_values[i]) : nodes_values_as_tensor[i];
if (i < static_cast<size_t>(nodes_missing_value_tracks_true.size()) && nodes_missing_value_tracks_true[i] == 1) {
node.flags |= static_cast<uint8_t>(MissingTrack::kTrue);
node.flags ^= static_cast<uint8_t>(MissingTrack::kTrue);
}
nodes_.push_back(std::move(node));
if (nodes_[node_pos].is_not_leaf()) {
Expand Down Expand Up @@ -733,19 +734,7 @@ TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ProcessTreeNodeLeave(
if (same_mode_) {
switch (root->mode()) {
case NODE_MODE::BRANCH_LEQ:
if (has_missing_tracks_) {
while (root->is_not_leaf()) {
val = x_data[root->feature_id];
root = val <= root->value_or_unique_weight || (root->is_missing_track_true() && _isnan_(val))
? root->truenode_or_weight.ptr
: root + 1;
}
} else {
while (root->is_not_leaf()) {
val = x_data[root->feature_id];
root = val <= root->value_or_unique_weight ? root->truenode_or_weight.ptr : root + 1;
}
}
TREE_FIND_VALUE(<=)
break;
case NODE_MODE::BRANCH_LT:
TREE_FIND_VALUE(<)
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/test/framework/inference_session_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3000,7 +3000,7 @@ TEST(InferenceSessionTests, Bench) {
// Configure RunOptions
RunOptions run_options;

const int MAX_ITER = 10;
const int MAX_ITER = 1;
std::vector<float> times = {};

for(size_t ITER = 0; ITER < MAX_ITER; ITER ++) {
Expand Down

0 comments on commit 2e79d4f

Please sign in to comment.