Skip to content

Commit

Permalink
[VitisAI] fix throw on dfs (#23678)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->
Add customized version for reverse dfs which can exit early if the graph
is altered.


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This is requested by MSFT. We used to use throw exception to exit
reverse dfs and this is not preferred by MSFT.

Co-authored-by: Yueqing Zhang <[email protected]>
  • Loading branch information
BoarQing and Yueqing Zhang authored Feb 14, 2025
1 parent c420052 commit 9d95e81
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 1 deletion.
1 change: 1 addition & 0 deletions onnxruntime/core/providers/vitisai/imp/global_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,7 @@ vaip_core::OrtApiForVaip* create_org_api_hook() {
the_global_api.graph_remove_initialized_tensor = [](Graph& graph, const std::string& tensor_name) {
graph.RemoveInitializedTensor(tensor_name);
};
the_global_api.graph_reverse_dfs_from_preemp = vaip::graph_reverse_dfs_from;
if (!s_library_vitisaiep.vaip_get_version) {
return reinterpret_cast<vaip_core::OrtApiForVaip*>(&(the_global_api.host_));
} else {
Expand Down
69 changes: 69 additions & 0 deletions onnxruntime/core/providers/vitisai/imp/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,75 @@ Node& graph_add_node(Graph& graph, const std::string& name, const std::string& o
return ret;
}

// copied from graph.cc, trying to exit the function early as leave function may change the validity of the graph
void graph_reverse_dfs_from(
const Graph& graph, gsl::span<const Node* const> from,
const std::function<bool(const Node*)>& enter,
const std::function<bool(const Node*)>& leave,
const std::function<bool(const Node*, const Node*)>& comp,
const std::function<bool(const Node* from, const Node* to)>&
stop) {
using WorkEntry = std::pair<const Node*, bool>; // bool represents leave or not
InlinedVector<WorkEntry> stack;
stack.reserve(from.size());
for (auto node : from) {
stack.emplace_back(node, false);
}

InlinedVector<bool> visited(graph.MaxNodeIndex(), false);
while (!stack.empty()) {
const WorkEntry last_entry = stack.back();
stack.pop_back();

if (last_entry.first == nullptr) {
continue;
}
const Node& n = *last_entry.first;

if (last_entry.second) {
// leave node
if (leave(&n)) {
return;
}
continue;
}

if (visited[n.Index()]) continue;

visited[n.Index()] = true;

if (enter) {
if (enter(&n)) {
return;
}
}
if (leave) stack.emplace_back(&n, true);

if (comp) {
InlinedVector<const Node*> sorted_nodes;
for (auto iter = n.InputNodesBegin(); iter != n.InputNodesEnd(); ++iter) {
if (stop && stop(&n, &(*iter))) continue;
sorted_nodes.push_back(&(*iter));
}
std::sort(sorted_nodes.begin(), sorted_nodes.end(), comp);
for (const auto* in : sorted_nodes) {
const NodeIndex idx = in->Index();
if (!visited[idx]) {
stack.emplace_back(in, false);
}
}
} else {
for (auto iter = n.InputNodesBegin(); iter != n.InputNodesEnd(); ++iter) {
if (stop && stop(&n, &(*iter))) continue;
const NodeIndex idx = (*iter).Index();
if (!visited[idx]) {
stack.emplace_back(graph.GetNode(idx), false);
}
}
}
}
}

void graph_remove_node(Graph& graph, const NodeInput& node_input) {
if (node_input.node == nullptr && node_input.node_arg != nullptr) {
assert(node_input.node_arg->Exists());
Expand Down
10 changes: 10 additions & 0 deletions onnxruntime/core/providers/vitisai/include/vaip/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#pragma once
#include "./node.h"
#include "vaip/my_ort.h"
#include <gsl/gsl>
#include <functional>
namespace vaip {
using namespace onnxruntime;

Expand All @@ -16,4 +18,12 @@ Node& graph_fuse(Graph& graph, const std::string& name, const std::string& op_ty
const std::vector<std::string>& inputs, const std::vector<std::string>& outputs,
const std::vector<std::string>& constant_initializers);
Model* model_clone(const Model& original_model, int64_t external_data_threshold);

void graph_reverse_dfs_from(
const Graph& graph, gsl::span<const Node* const> from,
const std::function<bool(const Node*)>& enter,
const std::function<bool(const Node*)>& leave,
const std::function<bool(const Node*, const Node*)>& comp,
const std::function<bool(const Node* from, const Node* to)>&
stop);
} // namespace vaip
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ struct OrtApi;

namespace vaip_core {

#define VAIP_ORT_API_MAJOR (13u)
#define VAIP_ORT_API_MAJOR (14u)
#define VAIP_ORT_API_MINOR (0u)
#define VAIP_ORT_API_PATCH (0u)
struct OrtApiForVaip {
Expand Down Expand Up @@ -243,6 +243,13 @@ struct OrtApiForVaip {
const std::vector<int64_t>& shape,
const std::vector<uint8_t>& data); // [101]
void (*graph_remove_initialized_tensor)(Graph& graph, const std::string& tensor_name); // [102]
void (*graph_reverse_dfs_from_preemp)(
const Graph& graph, gsl::span<const Node* const> from,
const std::function<bool(const Node*)>& enter,
const std::function<bool(const Node*)>& leave,
const std::function<bool(const Node*, const Node*)>& comp,
const std::function<bool(const Node* from, const Node* to)>&
stop); // [103]
};

#ifndef USE_VITISAI
Expand Down

0 comments on commit 9d95e81

Please sign in to comment.