Skip to content

Commit 6f9fc9d

Browse files
j2kuncopybara-github
authored andcommitted
Add a data structure for a tree of arithmetic operations
The basic structure just represents a (leaf-type-agnostic) tree of operations and provides a mechanism to create a visitor using std::visit. This is needed for lower_eval to separate the construction of the lowered arithmetic tree from the materialization of the IR. However, I think it will (with adaptations) be useful for other situations as well: - Symbolic noise analysis in #1817 - To simplify the core routine in operation-balancer (https://github.com/google/heir/blob/b0cf72da113e6c7282733f8ba6bfcb7754a7495c/lib/Transforms/OperationBalancer/OperationBalancer.cpp#L74) PiperOrigin-RevId: 770267746
1 parent b0cf72d commit 6f9fc9d

File tree

3 files changed

+199
-0
lines changed

3 files changed

+199
-0
lines changed

lib/Utils/ArithmeticTree.h

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
#ifndef THIRD_PARTY_HEIR_LIB_UTILS_ARITHMETICTREE_H_
2+
#define THIRD_PARTY_HEIR_LIB_UTILS_ARITHMETICTREE_H_
3+
4+
#include <memory>
5+
#include <variant>
6+
7+
namespace mlir {
8+
namespace heir {
9+
10+
// This file contains a generic tree structure that can be used for representing
11+
// arithmetic trees with leaf nodes of various types.
12+
template <typename T>
13+
struct ArithmeticTreeNode;
14+
15+
// A leaf node for the tree
16+
template <typename T>
17+
struct LeafNode {
18+
T value;
19+
};
20+
21+
using ConstantNode = LeafNode<double>;
22+
23+
template <typename T>
24+
struct AddNode {
25+
std::unique_ptr<ArithmeticTreeNode<T>> left;
26+
std::unique_ptr<ArithmeticTreeNode<T>> right;
27+
};
28+
29+
template <typename T>
30+
struct MultiplyNode {
31+
std::unique_ptr<ArithmeticTreeNode<T>> left;
32+
std::unique_ptr<ArithmeticTreeNode<T>> right;
33+
};
34+
35+
template <typename T>
36+
struct PowerNode {
37+
std::unique_ptr<ArithmeticTreeNode<T>> base;
38+
size_t exponent;
39+
};
40+
41+
template <typename T>
42+
struct ArithmeticTreeNode {
43+
public:
44+
std::variant<ConstantNode, LeafNode<T>, AddNode<T>, MultiplyNode<T>,
45+
PowerNode<T>>
46+
node_variant;
47+
48+
explicit ArithmeticTreeNode(double constant)
49+
: node_variant(ConstantNode{constant}) {}
50+
explicit ArithmeticTreeNode(const T& value)
51+
: node_variant(LeafNode<T>{value}) {}
52+
explicit ArithmeticTreeNode(T&& value)
53+
: node_variant(LeafNode<T>{std::move(value)}) {}
54+
55+
private:
56+
ArithmeticTreeNode() = default;
57+
58+
public:
59+
// Static factory methods
60+
static std::unique_ptr<ArithmeticTreeNode<T>> constant(double constant) {
61+
return std::unique_ptr<ArithmeticTreeNode<T>>(
62+
new ArithmeticTreeNode<T>(constant));
63+
}
64+
65+
static std::unique_ptr<ArithmeticTreeNode<T>> leaf(const T& value) {
66+
return std::unique_ptr<ArithmeticTreeNode<T>>(
67+
new ArithmeticTreeNode<T>(value));
68+
}
69+
70+
static std::unique_ptr<ArithmeticTreeNode<T>> add(
71+
std::unique_ptr<ArithmeticTreeNode<T>> lhs,
72+
std::unique_ptr<ArithmeticTreeNode<T>> rhs) {
73+
assert(lhs && rhs && "invalid add");
74+
auto node =
75+
std::unique_ptr<ArithmeticTreeNode<T>>(new ArithmeticTreeNode<T>());
76+
node->node_variant.template emplace<AddNode<T>>(
77+
AddNode{std::move(lhs), std::move(rhs)});
78+
return node;
79+
}
80+
81+
static std::unique_ptr<ArithmeticTreeNode<T>> mul(
82+
std::unique_ptr<ArithmeticTreeNode<T>> lhs,
83+
std::unique_ptr<ArithmeticTreeNode<T>> rhs) {
84+
assert(lhs && rhs && "invalid mul");
85+
auto node =
86+
std::unique_ptr<ArithmeticTreeNode<T>>(new ArithmeticTreeNode<T>());
87+
node->node_variant.template emplace<MultiplyNode<T>>(
88+
MultiplyNode{std::move(lhs), std::move(rhs)});
89+
return node;
90+
}
91+
92+
static std::unique_ptr<ArithmeticTreeNode<T>> power(
93+
std::unique_ptr<ArithmeticTreeNode<T>> base, size_t exponent) {
94+
assert(base && "invalid base for power");
95+
auto node =
96+
std::unique_ptr<ArithmeticTreeNode<T>>(new ArithmeticTreeNode<T>());
97+
node->node_variant.template emplace<PowerNode<T>>(
98+
PowerNode<T>{std::move(base), exponent});
99+
return node;
100+
}
101+
102+
// The presence of std::unique_ptr in AddNode, MultiplyNode, and PowerNode
103+
// makes these types non-copyable. Consequently, ArithmeticTreeNode
104+
// itself is move-only by default.
105+
ArithmeticTreeNode(const ArithmeticTreeNode&) = delete; // No copying
106+
ArithmeticTreeNode& operator=(const ArithmeticTreeNode&) =
107+
delete; // No copy assignment
108+
109+
ArithmeticTreeNode(ArithmeticTreeNode&& other) noexcept =
110+
default; // Allow move construction
111+
ArithmeticTreeNode& operator=(ArithmeticTreeNode&& other) noexcept =
112+
default; // Allow move assignment
113+
114+
// Visitor pattern
115+
template <typename VisitorFunc>
116+
decltype(auto) visit(VisitorFunc&& visitor) {
117+
return std::visit(std::forward<VisitorFunc>(visitor), node_variant);
118+
}
119+
120+
template <typename VisitorFunc>
121+
decltype(auto) visit(VisitorFunc&& visitor) const {
122+
return std::visit(std::forward<VisitorFunc>(visitor), node_variant);
123+
}
124+
};
125+
126+
} // namespace heir
127+
} // namespace mlir
128+
129+
#endif // THIRD_PARTY_HEIR_LIB_UTILS_ARITHMETICTREE_H_

lib/Utils/ArithmeticTreeTest.cpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#include "gmock/gmock.h" // from @googletest
2+
#include "gtest/gtest.h" // from @googletest
3+
#include "lib/Utils/ArithmeticTree.h"
4+
5+
namespace mlir {
6+
namespace heir {
7+
namespace {
8+
9+
using StringLeavedTree = ArithmeticTreeNode<std::string>;
10+
11+
struct FlattenedStringVisitor {
12+
std::string operator()(const ConstantNode& node) const {
13+
std::stringstream ss;
14+
ss << std::fixed << std::setprecision(2) << node.value;
15+
return ss.str();
16+
}
17+
18+
std::string operator()(const LeafNode<std::string>& node) const {
19+
return node.value;
20+
}
21+
22+
std::string operator()(const AddNode<std::string>& node) const {
23+
std::stringstream ss;
24+
ss << "(" << node.left->visit(*this) << " + " << node.right->visit(*this)
25+
<< ")";
26+
return ss.str();
27+
}
28+
29+
std::string operator()(const MultiplyNode<std::string>& node) const {
30+
std::stringstream ss;
31+
ss << node.left->visit(*this) << " * " << node.right->visit(*this);
32+
return ss.str();
33+
}
34+
35+
std::string operator()(const PowerNode<std::string>& node) const {
36+
std::stringstream ss;
37+
ss << "(" << node.base->visit(*this) << " ^ " << node.exponent << ")";
38+
return ss.str();
39+
}
40+
};
41+
42+
TEST(ArithmeticTreeTest, TestPrint) {
43+
auto root = StringLeavedTree::mul(
44+
StringLeavedTree::add(StringLeavedTree::leaf("x"),
45+
StringLeavedTree::constant(3.0)),
46+
StringLeavedTree::power(StringLeavedTree::leaf("y"), 2));
47+
48+
FlattenedStringVisitor visitor;
49+
std::string result = root->visit(visitor);
50+
EXPECT_EQ(result, "(x + 3.00) * (y ^ 2)");
51+
}
52+
53+
} // namespace
54+
} // namespace heir
55+
} // namespace mlir

lib/Utils/BUILD

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,3 +153,18 @@ cc_library(
153153
"@llvm-project//mlir:Support",
154154
],
155155
)
156+
157+
cc_library(
158+
name = "ArithmeticTree",
159+
srcs = ["ArithmeticTree.h"],
160+
hdrs = ["ArithmeticTree.h"],
161+
)
162+
163+
cc_test(
164+
name = "ArithmeticTreeTest",
165+
srcs = ["ArithmeticTreeTest.cpp"],
166+
deps = [
167+
":ArithmeticTree",
168+
"@googletest//:gtest_main",
169+
],
170+
)

0 commit comments

Comments
 (0)