Skip to content

Commit fb04740

Browse files
j2kuncopybara-github
authored andcommitted
Add a data structure for a tree of arithmetic operations
The basic structure just represents a (leaf-type-agnostic) tree of operations and provides a mechanism to create a visitor using std::visit. This is needed for lower_eval to separate the construction of the lowered arithmetic tree from the materialization of the IR. However, I think it will (with adaptations) be useful for other situations as well: - Symbolic noise analysis in #1817 - To simplify the core routine in operation-balancer (https://github.com/google/heir/blob/b0cf72da113e6c7282733f8ba6bfcb7754a7495c/lib/Transforms/OperationBalancer/OperationBalancer.cpp#L74) PiperOrigin-RevId: 770267746
1 parent 79f468d commit fb04740

File tree

3 files changed

+205
-0
lines changed

3 files changed

+205
-0
lines changed

lib/Utils/ArithmeticTree.h

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
#ifndef LIB_UTILS_ARITHMETICTREE_H_
2+
#define LIB_UTILS_ARITHMETICTREE_H_
3+
4+
#include <cstddef>
5+
#include <memory>
6+
#include <variant>
7+
8+
namespace mlir {
9+
namespace heir {
10+
11+
// This file contains a generic tree structure that can be used for representing
12+
// arithmetic trees with leaf nodes of various types.
13+
template <typename T>
14+
struct ArithmeticTreeNode;
15+
16+
// A leaf node for the tree
17+
template <typename T>
18+
struct LeafNode {
19+
T value;
20+
};
21+
22+
using ConstantNode = LeafNode<double>;
23+
24+
template <typename T>
25+
struct AddNode {
26+
std::unique_ptr<ArithmeticTreeNode<T>> left;
27+
std::unique_ptr<ArithmeticTreeNode<T>> right;
28+
};
29+
30+
template <typename T>
31+
struct MultiplyNode {
32+
std::unique_ptr<ArithmeticTreeNode<T>> left;
33+
std::unique_ptr<ArithmeticTreeNode<T>> right;
34+
};
35+
36+
template <typename T>
37+
struct PowerNode {
38+
std::unique_ptr<ArithmeticTreeNode<T>> base;
39+
size_t exponent;
40+
};
41+
42+
template <typename T>
43+
struct ArithmeticTreeNode {
44+
public:
45+
std::variant<ConstantNode, LeafNode<T>, AddNode<T>, MultiplyNode<T>,
46+
PowerNode<T>>
47+
node_variant;
48+
49+
explicit ArithmeticTreeNode(double constant)
50+
: node_variant(ConstantNode{constant}) {}
51+
explicit ArithmeticTreeNode(const T& value)
52+
: node_variant(LeafNode<T>{value}) {}
53+
explicit ArithmeticTreeNode(T&& value)
54+
: node_variant(LeafNode<T>{std::move(value)}) {}
55+
56+
private:
57+
ArithmeticTreeNode() = default;
58+
59+
public:
60+
// Static factory methods
61+
static std::unique_ptr<ArithmeticTreeNode<T>> constant(double constant) {
62+
return std::unique_ptr<ArithmeticTreeNode<T>>(
63+
new ArithmeticTreeNode<T>(constant));
64+
}
65+
66+
static std::unique_ptr<ArithmeticTreeNode<T>> leaf(const T& value) {
67+
return std::unique_ptr<ArithmeticTreeNode<T>>(
68+
new ArithmeticTreeNode<T>(value));
69+
}
70+
71+
static std::unique_ptr<ArithmeticTreeNode<T>> add(
72+
std::unique_ptr<ArithmeticTreeNode<T>> lhs,
73+
std::unique_ptr<ArithmeticTreeNode<T>> rhs) {
74+
assert(lhs && rhs && "invalid add");
75+
auto node =
76+
std::unique_ptr<ArithmeticTreeNode<T>>(new ArithmeticTreeNode<T>());
77+
// Note, to satisfy variant we need to use aggregate initialization inside
78+
// emplace
79+
node->node_variant.template emplace<AddNode<T>>(
80+
AddNode<T>{std::move(lhs), std::move(rhs)});
81+
return node;
82+
}
83+
84+
static std::unique_ptr<ArithmeticTreeNode<T>> mul(
85+
std::unique_ptr<ArithmeticTreeNode<T>> lhs,
86+
std::unique_ptr<ArithmeticTreeNode<T>> rhs) {
87+
assert(lhs && rhs && "invalid mul");
88+
auto node =
89+
std::unique_ptr<ArithmeticTreeNode<T>>(new ArithmeticTreeNode<T>());
90+
node->node_variant.template emplace<MultiplyNode<T>>(
91+
MultiplyNode<T>{std::move(lhs), std::move(rhs)});
92+
return node;
93+
}
94+
95+
static std::unique_ptr<ArithmeticTreeNode<T>> power(
96+
std::unique_ptr<ArithmeticTreeNode<T>> base, size_t exponent) {
97+
assert(base && "invalid base for power");
98+
auto node =
99+
std::unique_ptr<ArithmeticTreeNode<T>>(new ArithmeticTreeNode<T>());
100+
node->node_variant.template emplace<PowerNode<T>>(
101+
PowerNode<T>{std::move(base), exponent});
102+
return node;
103+
}
104+
// The presence of std::unique_ptr in AddNode, MultiplyNode, and PowerNode
105+
// makes these types non-copyable. Consequently, ArithmeticTreeNode
106+
// itself is move-only by default.
107+
ArithmeticTreeNode(const ArithmeticTreeNode&) = delete; // No copying
108+
ArithmeticTreeNode& operator=(const ArithmeticTreeNode&) =
109+
delete; // No copy assignment
110+
111+
ArithmeticTreeNode(ArithmeticTreeNode&& other) noexcept =
112+
default; // Allow move construction
113+
ArithmeticTreeNode& operator=(ArithmeticTreeNode&& other) noexcept =
114+
default; // Allow move assignment
115+
116+
// Visitor pattern
117+
template <typename VisitorFunc>
118+
decltype(auto) visit(VisitorFunc&& visitor) {
119+
return std::visit(std::forward<VisitorFunc>(visitor), node_variant);
120+
}
121+
122+
template <typename VisitorFunc>
123+
decltype(auto) visit(VisitorFunc&& visitor) const {
124+
return std::visit(std::forward<VisitorFunc>(visitor), node_variant);
125+
}
126+
};
127+
128+
} // namespace heir
129+
} // namespace mlir
130+
131+
#endif // LIB_UTILS_ARITHMETICTREE_H_

lib/Utils/ArithmeticTreeTest.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
#include <iomanip>
2+
#include <ios>
3+
#include <sstream>
4+
#include <string>
5+
6+
#include "gtest/gtest.h" // from @googletest
7+
#include "lib/Utils/ArithmeticTree.h"
8+
9+
namespace mlir {
10+
namespace heir {
11+
namespace {
12+
13+
using StringLeavedTree = ArithmeticTreeNode<std::string>;
14+
15+
struct FlattenedStringVisitor {
16+
std::string operator()(const ConstantNode& node) const {
17+
std::stringstream ss;
18+
ss << std::fixed << std::setprecision(2) << node.value;
19+
return ss.str();
20+
}
21+
22+
std::string operator()(const LeafNode<std::string>& node) const {
23+
return node.value;
24+
}
25+
26+
std::string operator()(const AddNode<std::string>& node) const {
27+
std::stringstream ss;
28+
ss << "(" << node.left->visit(*this) << " + " << node.right->visit(*this)
29+
<< ")";
30+
return ss.str();
31+
}
32+
33+
std::string operator()(const MultiplyNode<std::string>& node) const {
34+
std::stringstream ss;
35+
ss << node.left->visit(*this) << " * " << node.right->visit(*this);
36+
return ss.str();
37+
}
38+
39+
std::string operator()(const PowerNode<std::string>& node) const {
40+
std::stringstream ss;
41+
ss << "(" << node.base->visit(*this) << " ^ " << node.exponent << ")";
42+
return ss.str();
43+
}
44+
};
45+
46+
TEST(ArithmeticTreeTest, TestPrint) {
47+
auto root = StringLeavedTree::mul(
48+
StringLeavedTree::add(StringLeavedTree::leaf("x"),
49+
StringLeavedTree::constant(3.0)),
50+
StringLeavedTree::power(StringLeavedTree::leaf("y"), 2));
51+
52+
FlattenedStringVisitor visitor;
53+
std::string result = root->visit(visitor);
54+
EXPECT_EQ(result, "(x + 3.00) * (y ^ 2)");
55+
}
56+
57+
} // namespace
58+
} // namespace heir
59+
} // namespace mlir

lib/Utils/BUILD

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,3 +153,18 @@ cc_library(
153153
"@llvm-project//mlir:Support",
154154
],
155155
)
156+
157+
cc_library(
158+
name = "ArithmeticTree",
159+
srcs = ["ArithmeticTree.h"],
160+
hdrs = ["ArithmeticTree.h"],
161+
)
162+
163+
cc_test(
164+
name = "ArithmeticTreeTest",
165+
srcs = ["ArithmeticTreeTest.cpp"],
166+
deps = [
167+
":ArithmeticTree",
168+
"@googletest//:gtest_main",
169+
],
170+
)

0 commit comments

Comments
 (0)