Skip to content

Commit

Permalink
Add iterators on ASTs, allowing for loops, std::find/std::find_if, st…
Browse files Browse the repository at this point in the history
…d::count/std::count_if, etc. (#2387)

Co-authored-by: OMNES Florian <[email protected]>
  • Loading branch information
flomnes and OMNES Florian authored Sep 10, 2024
1 parent 24256a1 commit 775224e
Show file tree
Hide file tree
Showing 5 changed files with 309 additions and 1 deletion.
11 changes: 10 additions & 1 deletion src/solver/expressions/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ set(SRC_Expressions
include/antares/solver/expressions/IName.h
)


source_group("expressions" FILES ${SRC_Expressions})
add_library(antares-solver-expressions
${SRC_Expressions})
Expand All @@ -67,7 +66,17 @@ target_include_directories(antares-solver-expressions
target_link_libraries(antares-solver-expressions
PUBLIC
Antares::logs
)



add_library(antares-solver-expressions-iterators
iterators/pre-order.cpp
include/antares/solver/expressions/iterators/pre-order.h
)

target_link_libraries(antares-solver-expressions-iterators PRIVATE antares-solver-expressions)

install(DIRECTORY include/antares
DESTINATION "include"
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#pragma once

#include <stack>
#include <vector>

namespace Antares::Solver::Nodes
{
// Forward-declaration is enough

class Node;

// PreOrder Iterator for AST
class ASTPreOrderIterator
{
std::stack<Node*> nodeStack;

public:
// Iterator type aliases
using iterator_category = std::forward_iterator_tag;
using value_type = Node;
using difference_type = std::ptrdiff_t;
using pointer = Node*;
using reference = Node&;

// Constructor
explicit ASTPreOrderIterator(Node* root = nullptr);

// Dereference operator
reference operator*() const;

// Pointer access operator
pointer operator->() const;

// Increment operator (pre-order traversal)
ASTPreOrderIterator& operator++();

// Equality comparison
bool operator==(const ASTPreOrderIterator& other) const;

// Inequality comparison
bool operator!=(const ASTPreOrderIterator& other) const;
};

// AST container class to expose begin/end iterators
class AST
{
Node* root;

public:
explicit AST(Node* rootNode);

// Begin iterator
ASTPreOrderIterator begin();

// End iterator (indicating traversal is complete)
ASTPreOrderIterator end();
};
} // namespace Antares::Solver::Nodes
103 changes: 103 additions & 0 deletions src/solver/expressions/iterators/pre-order.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
#include <ranges>

#include <antares/solver/expressions/iterators/pre-order.h>
#include <antares/solver/expressions/nodes/ExpressionsNodes.h>

namespace Antares::Solver::Nodes
{
namespace
{
// Children, left to right
std::vector<Node*> childrenLeftToRight(Node* node)
{
if (auto* bin = dynamic_cast<BinaryNode*>(node))
{
return {bin->left(), bin->right()};
}
else if (auto* unary = dynamic_cast<UnaryNode*>(node))
{
return {unary->child()};
}
return {};
}
} // namespace

// Constructor
ASTPreOrderIterator::ASTPreOrderIterator(Node* root)
{
if (root)
{
nodeStack.push(root);
}
}

// Dereference operator
ASTPreOrderIterator::reference ASTPreOrderIterator::operator*() const
{
return *nodeStack.top();
}

// Pointer access operator
ASTPreOrderIterator::pointer ASTPreOrderIterator::operator->() const
{
return nodeStack.top();
}

// Increment operator (pre-order traversal)
ASTPreOrderIterator& ASTPreOrderIterator::operator++()
{
if (nodeStack.empty())
{
return *this;
}

Node* current = nodeStack.top();
nodeStack.pop();

const auto children = childrenLeftToRight(current);
// Push children in reverse order to process them in left-to-right order
for (auto* it: children | std::views::reverse)
{
nodeStack.push(it);
}

return *this;
}

// Equality comparison
bool ASTPreOrderIterator::operator==(const ASTPreOrderIterator& other) const
{
if (nodeStack.empty() && other.nodeStack.empty())
{
return true;
}
if (nodeStack.empty() || other.nodeStack.empty())
{
return false;
}
return nodeStack.top() == other.nodeStack.top();
}

// Inequality comparison
bool ASTPreOrderIterator::operator!=(const ASTPreOrderIterator& other) const
{
return !(*this == other);
}

AST::AST(Node* rootNode):
root(rootNode)
{
}

// Begin iterator
ASTPreOrderIterator AST::begin()
{
return ASTPreOrderIterator(root);
}

// End iterator (indicating traversal is complete)
ASTPreOrderIterator AST::end()
{
return ASTPreOrderIterator(nullptr);
}
} // namespace Antares::Solver::Nodes
2 changes: 2 additions & 0 deletions src/tests/src/solver/expressions/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@ target_sources(${EXECUTABLE_NAME}
test_CompareVisitor.cpp
test_CloneVisitor.cpp
test_DeepWideTrees.cpp
test_Iterators.cpp
)

target_link_libraries(${EXECUTABLE_NAME}
PRIVATE
Boost::unit_test_framework
antares-solver-expressions
antares-solver-expressions-iterators
)

# Storing tests-ts-numbers under the folder Unit-tests in the IDE
Expand Down
136 changes: 136 additions & 0 deletions src/tests/src/solver/expressions/test_Iterators.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/*
* Copyright 2007-2024, RTE (https://www.rte-france.com)
* See AUTHORS.txt
* SPDX-License-Identifier: MPL-2.0
* This file is part of Antares-Simulator,
* Adequacy and Performance assessment for interconnected energy networks.
*
* Antares_Simulator is free software: you can redistribute it and/or modify
* it under the terms of the Mozilla Public Licence 2.0 as published by
* the Mozilla Foundation, either version 2 of the License, or
* (at your option) any later version.
*
* Antares_Simulator is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* Mozilla Public Licence 2.0 for more details.
*
* You should have received a copy of the Mozilla Public Licence 2.0
* along with Antares_Simulator. If not, see <https://opensource.org/license/mpl-2-0/>.
*/

#define WIN32_LEAN_AND_MEAN

#include <algorithm>

#include <boost/test/unit_test.hpp>

#include <antares/solver/expressions/Registry.hxx>
#include <antares/solver/expressions/iterators/pre-order.h>
#include <antares/solver/expressions/nodes/ExpressionsNodes.h>

using namespace Antares::Solver;
using namespace Antares::Solver::Nodes;

BOOST_AUTO_TEST_SUITE(_Iterator_)

static Node* simpleExpression(Registry<Node>& registry)
{
return registry.create<AddNode>(registry.create<LiteralNode>(2.),
registry.create<LiteralNode>(21.));
}

BOOST_FIXTURE_TEST_CASE(empty_ast_begin_is_end, Registry<Node>)
{
AST ast(nullptr);
BOOST_CHECK(ast.begin() == ast.end());
}

BOOST_FIXTURE_TEST_CASE(simple_end_is_end, Registry<Node>)
{
AST ast(create<LiteralNode>(32.));
BOOST_CHECK(ast.end() == ast.end());
}

BOOST_FIXTURE_TEST_CASE(dereference_op, Registry<Node>)
{
AST ast(create<LiteralNode>(21.));
auto it = ast.begin();
const std::string expected("LiteralNode");
BOOST_CHECK_EQUAL(it->name(), expected);
BOOST_CHECK_EQUAL((*it).name(), expected);
}

BOOST_FIXTURE_TEST_CASE(unary_dereference, Registry<Node>)
{
AST ast(create<NegationNode>(nullptr));
auto it = ast.begin();
BOOST_CHECK(!it->name().empty());
BOOST_CHECK(!(*it).name().empty());
}

BOOST_FIXTURE_TEST_CASE(count_literal_nodes_for_loop, Registry<Node>)
{
int count_lit = 0;
for (auto& node: AST(simpleExpression(*this)))
{
if (dynamic_cast<LiteralNode*>(&node))
{
count_lit++;
}
}
BOOST_CHECK_EQUAL(count_lit, 2);
}

BOOST_FIXTURE_TEST_CASE(count_literal_nodes_count_if, Registry<Node>)
{
AST ast(simpleExpression(*this));
int count_lit = std::count_if(ast.begin(),
ast.end(),
[](Node& node)
{ return dynamic_cast<LiteralNode*>(&node) != nullptr; });

BOOST_CHECK_EQUAL(count_lit, 2);
}

BOOST_FIXTURE_TEST_CASE(find_if_not_found, Registry<Node>)
{
AST ast(simpleExpression(*this));
auto it = std::find_if(ast.begin(),
ast.end(),
[](Node& node)
{ return dynamic_cast<MultiplicationNode*>(&node) != nullptr; });
BOOST_CHECK(it == ast.end());
}

BOOST_FIXTURE_TEST_CASE(find_if_found, Registry<Node>)
{
AST ast(simpleExpression(*this));
auto it = std::find_if(ast.begin(),
ast.end(),
[](Node& node) { return dynamic_cast<LiteralNode*>(&node) != nullptr; });
BOOST_CHECK(it != ast.end());
auto* res = dynamic_cast<LiteralNode*>(&*it);
BOOST_REQUIRE(res);
BOOST_CHECK_EQUAL(res->value(), 2.);
}

BOOST_FIXTURE_TEST_CASE(distance_is_3, Registry<Node>)
{
AST ast(simpleExpression(*this));
BOOST_CHECK_EQUAL(std::distance(ast.begin(), ast.end()), 3);
}

BOOST_FIXTURE_TEST_CASE(distance_unary, Registry<Node>)
{
AST ast(create<NegationNode>(create<NegationNode>(create<LiteralNode>(32.))));
BOOST_CHECK_EQUAL(std::distance(ast.begin(), ast.end()), 3);
}

BOOST_FIXTURE_TEST_CASE(distance_nullptr_is_3, Registry<Node>)
{
AST ast(create<AddNode>(nullptr, create<LiteralNode>(2.)));
BOOST_CHECK_EQUAL(std::distance(ast.begin(), ast.end()), 3);
}

BOOST_AUTO_TEST_SUITE_END()

0 comments on commit 775224e

Please sign in to comment.