Skip to content

Commit

Permalink
Implement best first search (#254)
Browse files Browse the repository at this point in the history
* first implementation and tests

* add docs and minor changes

* minor change

* minor change
  • Loading branch information
pradkrish authored Dec 23, 2022
1 parent c9ab362 commit 5828120
Show file tree
Hide file tree
Showing 3 changed files with 236 additions and 0 deletions.
89 changes: 89 additions & 0 deletions include/Graph/Graph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,20 @@ namespace CXXGRAPH
* errorMessage: "" if no error ELSE report the encountered error
*/
virtual const MstResult kruskal() const;
/**
* \brief
* Function runs the best first search algorithm over the graph
* using an evaluation function to decide which adjacent node is
* most promising to explore
* Note: No Thread Safe
*
* @param source source node
* @param target target node
* @returns a struct with a vector of Nodes if target is reachable else ERROR in case
* if target is not reachable or there is an error in the computation.
*
*/
virtual BestFirstSearchResult<T> best_first_search(const Node<T> &source, const Node<T> &target) const;
/**
* \brief
* Function performs the breadth first search algorithm over the graph
Expand Down Expand Up @@ -1587,6 +1601,81 @@ namespace CXXGRAPH
return result;
}

template <typename T>
BestFirstSearchResult<T> Graph<T>::best_first_search(const Node<T> &source, const Node<T> &target) const
{
BestFirstSearchResult<T> result;
auto &nodeSet = Graph<T>::getNodeSet();
using pq_type = std::pair<double, const Node<T> *>;

if(std::find(nodeSet.begin(), nodeSet.end(), &source) == nodeSet.end())
{
result.errorMessage = ERR_SOURCE_NODE_NOT_IN_GRAPH;
return result;
}

if(std::find(nodeSet.begin(), nodeSet.end(), &target) == nodeSet.end())
{
result.errorMessage = ERR_TARGET_NODE_NOT_IN_GRAPH;
return result;
}

auto adj = Graph<T>::getAdjMatrix();
std::priority_queue<pq_type, std::vector<pq_type>, std::greater<pq_type>> pq;

std::vector<Node<T>> visited;
visited.push_back(source);
pq.push(std::make_pair(static_cast<T>(0), &source));

while (!pq.empty())
{
const Node<T> *currentNode = pq.top().second;
pq.pop();
result.nodesInBestSearchOrder.push_back(*currentNode);

if (*currentNode == target)
{
break;
}
if (adj.find(currentNode) != adj.end())
{
for (const auto &elem : adj.at(currentNode))
{
if (elem.second->isWeighted().has_value())
{
if (elem.second->isDirected().has_value())
{
const DirectedWeightedEdge<T> *dw_edge = static_cast<const DirectedWeightedEdge<T> *>(elem.second);
if (std::find(visited.begin(), visited.end(), *(elem.first)) == visited.end())
{
visited.push_back(*(elem.first));
pq.push(std::make_pair(dw_edge->getWeight(), elem.first));
}
}
else
{
const UndirectedWeightedEdge<T> *dw_edge = static_cast<const UndirectedWeightedEdge<T> *>(elem.second);
if (std::find(visited.begin(), visited.end(), *(elem.first)) == visited.end())
{
visited.push_back(*(elem.first));
pq.push(std::make_pair(dw_edge->getWeight(), elem.first));
}
}
}
else
{
result.errorMessage = ERR_NO_WEIGHTED_EDGE;
result.nodesInBestSearchOrder.clear();
return result;
}
}
}
}

result.success = true;
return result;
}

template <typename T>
const std::vector<Node<T>> Graph<T>::breadth_first_search(const Node<T> &start) const
{
Expand Down
11 changes: 11 additions & 0 deletions include/Utility/Typedef.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,17 @@ namespace CXXGRAPH
template <typename T>
using TopoSortResult = TopoSortResult_struct<T>;

/// Struct that contains the information about Best First Search Algorithm results
template <typename T>
struct BestFirstSearchResult_struct
{
bool success = false;
std::string errorMessage = "";
std::vector<Node<T>> nodesInBestSearchOrder = {};
};
template <typename T>
using BestFirstSearchResult = BestFirstSearchResult_struct<T>;

/// Struct that contains the information about the partitioning statistics


Expand Down
136 changes: 136 additions & 0 deletions test/BestFirstSearchTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
#include "gtest/gtest.h"
#include "CXXGraph.hpp"
#include <vector>

TEST(BestFirstSearchTest, source_node_missing)
{
CXXGRAPH::Node<int> node1("1", 1);
CXXGRAPH::Node<int> node2("2", 2);
CXXGRAPH::Node<int> node3("3", 3);
CXXGRAPH::Node<int> node4("4", 4);
std::pair<const CXXGRAPH::Node<int> *, const CXXGRAPH::Node<int> *> pairNode(&node1, &node2);
CXXGRAPH::DirectedWeightedEdge<int> edge1(1, pairNode, 1);
CXXGRAPH::DirectedWeightedEdge<int> edge2(2, node2, node3, 3);
CXXGRAPH::DirectedWeightedEdge<int> edge3(3, node1, node3, 6);
CXXGRAPH::T_EdgeSet<int> edgeSet;
edgeSet.insert(&edge1);
edgeSet.insert(&edge2);
edgeSet.insert(&edge3);
CXXGRAPH::Graph<int> graph(edgeSet);
CXXGRAPH::BestFirstSearchResult<int> res = graph.best_first_search(node4, node1);
ASSERT_EQ(res.success, false);
ASSERT_EQ(res.nodesInBestSearchOrder.size(), 0);
ASSERT_EQ(res.errorMessage, CXXGRAPH::ERR_SOURCE_NODE_NOT_IN_GRAPH);
}

TEST(BestFirstSearchTest, target_node_missing)
{
CXXGRAPH::Node<int> node1("1", 1);
CXXGRAPH::Node<int> node2("2", 2);
CXXGRAPH::Node<int> node3("3", 3);
CXXGRAPH::Node<int> node4("4", 4);
std::pair<const CXXGRAPH::Node<int> *, const CXXGRAPH::Node<int> *> pairNode(&node1, &node2);
CXXGRAPH::DirectedWeightedEdge<int> edge1(1, pairNode, 1);
CXXGRAPH::DirectedWeightedEdge<int> edge2(2, node2, node3, 3);
CXXGRAPH::DirectedWeightedEdge<int> edge3(3, node1, node3, 6);
CXXGRAPH::T_EdgeSet<int> edgeSet;
edgeSet.insert(&edge1);
edgeSet.insert(&edge2);
edgeSet.insert(&edge3);
CXXGRAPH::Graph<int> graph(edgeSet);
CXXGRAPH::BestFirstSearchResult<int> res = graph.best_first_search(node1, node4);
ASSERT_EQ(res.success, false);
ASSERT_EQ(res.nodesInBestSearchOrder.size(), 0);
ASSERT_EQ(res.errorMessage, CXXGRAPH::ERR_TARGET_NODE_NOT_IN_GRAPH);
}

TEST(BestFirstSearchTest, correct_example_small)
{
CXXGRAPH::Node<int> node1("1", 1);
CXXGRAPH::Node<int> node2("2", 2);
CXXGRAPH::Node<int> node3("3", 3);
std::pair<const CXXGRAPH::Node<int> *, const CXXGRAPH::Node<int> *> pairNode(&node1, &node2);
CXXGRAPH::DirectedWeightedEdge<int> edge1(1, pairNode, 1);
CXXGRAPH::DirectedWeightedEdge<int> edge2(2, node2, node3, 3);
CXXGRAPH::DirectedWeightedEdge<int> edge3(3, node1, node3, 6);
CXXGRAPH::T_EdgeSet<int> edgeSet;
edgeSet.insert(&edge1);
edgeSet.insert(&edge2);
edgeSet.insert(&edge3);
CXXGRAPH::Graph<int> graph(edgeSet);
CXXGRAPH::BestFirstSearchResult<int> res = graph.best_first_search(node1, node2);
ASSERT_EQ(res.success, true);
ASSERT_EQ(res.nodesInBestSearchOrder.size(), 2);
ASSERT_EQ(res.errorMessage, "");
}

TEST(BestFirstSearchTest, source_target_same)
{
CXXGRAPH::Node<int> node1("1", 1);
CXXGRAPH::Node<int> node2("2", 2);
CXXGRAPH::Node<int> node3("3", 3);
std::pair<const CXXGRAPH::Node<int> *, const CXXGRAPH::Node<int> *> pairNode(&node1, &node2);
CXXGRAPH::DirectedWeightedEdge<int> edge1(1, pairNode, 1);
CXXGRAPH::DirectedWeightedEdge<int> edge2(2, node2, node3, 3);
CXXGRAPH::DirectedWeightedEdge<int> edge3(3, node1, node3, 6);
CXXGRAPH::T_EdgeSet<int> edgeSet;
edgeSet.insert(&edge1);
edgeSet.insert(&edge2);
edgeSet.insert(&edge3);
CXXGRAPH::Graph<int> graph(edgeSet);
CXXGRAPH::BestFirstSearchResult<int> res = graph.best_first_search(node1, node1);
ASSERT_EQ(res.success, true);
ASSERT_EQ(res.nodesInBestSearchOrder.size(), 1);
ASSERT_EQ(res.errorMessage, "");
}

TEST(BestFirstSearchTest, correct_example_big)
{
CXXGRAPH::Node<int> node1("1", 1);
CXXGRAPH::Node<int> node2("2", 2);
CXXGRAPH::Node<int> node3("3", 3);
CXXGRAPH::Node<int> node4("4", 4);
CXXGRAPH::Node<int> node5("5", 5);
CXXGRAPH::Node<int> node6("6", 6);
CXXGRAPH::Node<int> node7("7", 7);
CXXGRAPH::Node<int> node8("8", 8);
CXXGRAPH::Node<int> node9("9", 9);
CXXGRAPH::Node<int> node10("10", 10);
CXXGRAPH::Node<int> node11("11", 11);
CXXGRAPH::Node<int> node12("12", 12);
CXXGRAPH::Node<int> node13("13", 13);
CXXGRAPH::Node<int> node14("14", 14);
std::pair<const CXXGRAPH::Node<int> *, const CXXGRAPH::Node<int> *> pairNode(&node1, &node2);
CXXGRAPH::DirectedWeightedEdge<int> edge1(1, node1, node2, 3);
CXXGRAPH::DirectedWeightedEdge<int> edge2(2, node1, node3, 6);
CXXGRAPH::DirectedWeightedEdge<int> edge3(3, node1, node4, 5);
CXXGRAPH::DirectedWeightedEdge<int> edge4(4, node2, node5, 9);
CXXGRAPH::DirectedWeightedEdge<int> edge5(5, node2, node6, 8);
CXXGRAPH::DirectedWeightedEdge<int> edge6(6, node3, node7, 12);
CXXGRAPH::DirectedWeightedEdge<int> edge7(7, node3, node8, 14);
CXXGRAPH::DirectedWeightedEdge<int> edge8(8, node4, node9, 7);
CXXGRAPH::DirectedWeightedEdge<int> edge9(9, node9, node10, 5);
CXXGRAPH::DirectedWeightedEdge<int> edge10(10, node9, node11, 6);
CXXGRAPH::DirectedWeightedEdge<int> edge11(10, node10, node12, 1);
CXXGRAPH::DirectedWeightedEdge<int> edge12(10, node10, node13, 10);
CXXGRAPH::DirectedWeightedEdge<int> edge13(10, node10, node14, 2);
CXXGRAPH::T_EdgeSet<int> edgeSet;
edgeSet.insert(&edge1);
edgeSet.insert(&edge2);
edgeSet.insert(&edge3);
edgeSet.insert(&edge4);
edgeSet.insert(&edge5);
edgeSet.insert(&edge6);
edgeSet.insert(&edge7);
edgeSet.insert(&edge8);
edgeSet.insert(&edge9);
edgeSet.insert(&edge10);
edgeSet.insert(&edge11);
edgeSet.insert(&edge12);
edgeSet.insert(&edge13);
CXXGRAPH::Graph<int> graph(edgeSet);
CXXGRAPH::BestFirstSearchResult<int> res = graph.best_first_search(node1, node10);
ASSERT_EQ(res.success, true);
ASSERT_EQ(res.nodesInBestSearchOrder.size(), 6);
ASSERT_EQ(res.errorMessage, "");
}

0 comments on commit 5828120

Please sign in to comment.