Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[VitisAI] fix throw on dfs #23678

Merged
merged 1 commit into from
Feb 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions onnxruntime/core/providers/vitisai/imp/global_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,7 @@ vaip_core::OrtApiForVaip* create_org_api_hook() {
the_global_api.graph_remove_initialized_tensor = [](Graph& graph, const std::string& tensor_name) {
graph.RemoveInitializedTensor(tensor_name);
};
the_global_api.graph_reverse_dfs_from_preemp = vaip::graph_reverse_dfs_from;
if (!s_library_vitisaiep.vaip_get_version) {
return reinterpret_cast<vaip_core::OrtApiForVaip*>(&(the_global_api.host_));
} else {
Expand Down
69 changes: 69 additions & 0 deletions onnxruntime/core/providers/vitisai/imp/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,75 @@
return ret;
}

// copied from graph.cc, trying to exit the function early as leave function may change the validity of the graph
void graph_reverse_dfs_from(
const Graph& graph, gsl::span<const Node* const> from,
const std::function<bool(const Node*)>& enter,
const std::function<bool(const Node*)>& leave,
const std::function<bool(const Node*, const Node*)>& comp,
const std::function<bool(const Node* from, const Node* to)>&
stop) {
using WorkEntry = std::pair<const Node*, bool>; // bool represents leave or not
InlinedVector<WorkEntry> stack;
stack.reserve(from.size());
for (auto node : from) {
stack.emplace_back(node, false);
}

InlinedVector<bool> visited(graph.MaxNodeIndex(), false);
while (!stack.empty()) {
const WorkEntry last_entry = stack.back();
stack.pop_back();

if (last_entry.first == nullptr) {
continue;
}
const Node& n = *last_entry.first;

if (last_entry.second) {
// leave node
if (leave(&n)) {
return;
}
continue;
}

if (visited[n.Index()]) continue;

visited[n.Index()] = true;

if (enter) {
if (enter(&n)) {
return;
}
}
if (leave) stack.emplace_back(&n, true);

if (comp) {
InlinedVector<const Node*> sorted_nodes;
for (auto iter = n.InputNodesBegin(); iter != n.InputNodesEnd(); ++iter) {
if (stop && stop(&n, &(*iter))) continue;
sorted_nodes.push_back(&(*iter));
}
std::sort(sorted_nodes.begin(), sorted_nodes.end(), comp);

Check warning on line 142 in onnxruntime/core/providers/vitisai/imp/graph.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <algorithm> for sort [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/vitisai/imp/graph.cc:142: Add #include <algorithm> for sort [build/include_what_you_use] [4]
for (const auto* in : sorted_nodes) {
const NodeIndex idx = in->Index();
if (!visited[idx]) {
stack.emplace_back(in, false);
}
}
} else {
for (auto iter = n.InputNodesBegin(); iter != n.InputNodesEnd(); ++iter) {
if (stop && stop(&n, &(*iter))) continue;
const NodeIndex idx = (*iter).Index();
if (!visited[idx]) {
stack.emplace_back(graph.GetNode(idx), false);
}
}
}
}
}

void graph_remove_node(Graph& graph, const NodeInput& node_input) {
if (node_input.node == nullptr && node_input.node_arg != nullptr) {
assert(node_input.node_arg->Exists());
Expand Down
10 changes: 10 additions & 0 deletions onnxruntime/core/providers/vitisai/include/vaip/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#pragma once
#include "./node.h"
#include "vaip/my_ort.h"
#include <gsl/gsl>

Check warning on line 6 in onnxruntime/core/providers/vitisai/include/vaip/graph.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Found C system header after other header. Should be: graph.h, c system, c++ system, other. [build/include_order] [4] Raw Output: onnxruntime/core/providers/vitisai/include/vaip/graph.h:6: Found C system header after other header. Should be: graph.h, c system, c++ system, other. [build/include_order] [4]
#include <functional>

Check warning on line 7 in onnxruntime/core/providers/vitisai/include/vaip/graph.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Found C++ system header after other header. Should be: graph.h, c system, c++ system, other. [build/include_order] [4] Raw Output: onnxruntime/core/providers/vitisai/include/vaip/graph.h:7: Found C++ system header after other header. Should be: graph.h, c system, c++ system, other. [build/include_order] [4]
namespace vaip {
using namespace onnxruntime;

Expand All @@ -16,4 +18,12 @@
const std::vector<std::string>& inputs, const std::vector<std::string>& outputs,
const std::vector<std::string>& constant_initializers);
Model* model_clone(const Model& original_model, int64_t external_data_threshold);

void graph_reverse_dfs_from(
const Graph& graph, gsl::span<const Node* const> from,
const std::function<bool(const Node*)>& enter,
const std::function<bool(const Node*)>& leave,
const std::function<bool(const Node*, const Node*)>& comp,
const std::function<bool(const Node* from, const Node* to)>&
stop);
} // namespace vaip
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ struct OrtApi;

namespace vaip_core {

#define VAIP_ORT_API_MAJOR (13u)
#define VAIP_ORT_API_MAJOR (14u)
#define VAIP_ORT_API_MINOR (0u)
#define VAIP_ORT_API_PATCH (0u)
struct OrtApiForVaip {
Expand Down Expand Up @@ -243,6 +243,13 @@ struct OrtApiForVaip {
const std::vector<int64_t>& shape,
const std::vector<uint8_t>& data); // [101]
void (*graph_remove_initialized_tensor)(Graph& graph, const std::string& tensor_name); // [102]
void (*graph_reverse_dfs_from_preemp)(
const Graph& graph, gsl::span<const Node* const> from,
const std::function<bool(const Node*)>& enter,
const std::function<bool(const Node*)>& leave,
const std::function<bool(const Node*, const Node*)>& comp,
const std::function<bool(const Node* from, const Node* to)>&
stop); // [103]
};

#ifndef USE_VITISAI
Expand Down
Loading