Skip to content

Commit d951434

Browse files
j2kuncopybara-github
authored andcommitted
Add a data structure for a DAG of arithmetic operations
The basic structure just represents a (leaf-type-agnostic) DAG of operations and provides a mechanism to create a visitor using std::visit (with a base class for a caching visitor). This is needed for lower_eval to separate the construction of the lowered arithmetic tree from the materialization of the IR. However, I think it will (with adaptations) be useful for other situations as well: - Symbolic noise analysis in #1817 - To simplify the core routine in operation-balancer (https://github.com/google/heir/blob/b0cf72da113e6c7282733f8ba6bfcb7754a7495c/lib/Transforms/OperationBalancer/OperationBalancer.cpp#L74) PiperOrigin-RevId: 770267746
1 parent 0fe389a commit d951434

File tree

3 files changed

+317
-0
lines changed

3 files changed

+317
-0
lines changed

lib/Utils/ArithmeticDag.h

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
#ifndef LIB_UTILS_ARITHMETICDAG_H_
2+
#define LIB_UTILS_ARITHMETICDAG_H_
3+
4+
#include <cassert>
5+
#include <cstddef>
6+
#include <memory>
7+
#include <unordered_map>
8+
#include <utility>
9+
#include <variant>
10+
11+
namespace mlir {
12+
namespace heir {
13+
14+
// This file contains a generic DAG structure that can be used for representing
15+
// arithmetic DAGs with leaf nodes of various types.
16+
template <typename T>
17+
struct ArithmeticDagNode;
18+
19+
// A leaf node for the DAG
20+
template <typename T>
21+
struct LeafNode {
22+
T value;
23+
};
24+
25+
struct ConstantNode {
26+
double value;
27+
};
28+
29+
template <typename T>
30+
struct AddNode {
31+
std::shared_ptr<ArithmeticDagNode<T>> left;
32+
std::shared_ptr<ArithmeticDagNode<T>> right;
33+
};
34+
35+
template <typename T>
36+
struct MultiplyNode {
37+
std::shared_ptr<ArithmeticDagNode<T>> left;
38+
std::shared_ptr<ArithmeticDagNode<T>> right;
39+
};
40+
41+
template <typename T>
42+
struct PowerNode {
43+
std::shared_ptr<ArithmeticDagNode<T>> base;
44+
size_t exponent;
45+
};
46+
47+
template <typename T>
48+
struct ArithmeticDagNode {
49+
public:
50+
std::variant<ConstantNode, LeafNode<T>, AddNode<T>, MultiplyNode<T>,
51+
PowerNode<T>>
52+
node_variant;
53+
54+
explicit ArithmeticDagNode(const T& value)
55+
: node_variant(LeafNode<T>{value}) {}
56+
explicit ArithmeticDagNode(T&& value)
57+
: node_variant(LeafNode<T>{std::move(value)}) {}
58+
59+
private:
60+
ArithmeticDagNode() = default;
61+
62+
public:
63+
// Static factory methods
64+
static std::shared_ptr<ArithmeticDagNode<T>> leaf(const T& value) {
65+
// This factory method differs from the others because T may not have a
66+
// default constructor to use with emplace. In that case, we need to rely
67+
// on the move or copy constructors, which corresponds to the two
68+
// ArithmeticDagNode constructors above.
69+
return std::shared_ptr<ArithmeticDagNode<T>>(
70+
new ArithmeticDagNode<T>(value));
71+
}
72+
73+
static std::shared_ptr<ArithmeticDagNode<T>> constant(double constant) {
74+
auto node =
75+
std::shared_ptr<ArithmeticDagNode<T>>(new ArithmeticDagNode<T>());
76+
// Note, to satisfy variant we need to use aggregate initialization inside
77+
// emplace
78+
node->node_variant.template emplace<ConstantNode>(ConstantNode{constant});
79+
return node;
80+
}
81+
82+
static std::shared_ptr<ArithmeticDagNode<T>> add(
83+
std::shared_ptr<ArithmeticDagNode<T>> lhs,
84+
std::shared_ptr<ArithmeticDagNode<T>> rhs) {
85+
assert(lhs && rhs && "invalid add");
86+
auto node =
87+
std::shared_ptr<ArithmeticDagNode<T>>(new ArithmeticDagNode<T>());
88+
node->node_variant.template emplace<AddNode<T>>(
89+
AddNode<T>{std::move(lhs), std::move(rhs)});
90+
return node;
91+
}
92+
93+
static std::shared_ptr<ArithmeticDagNode<T>> mul(
94+
std::shared_ptr<ArithmeticDagNode<T>> lhs,
95+
std::shared_ptr<ArithmeticDagNode<T>> rhs) {
96+
assert(lhs && rhs && "invalid mul");
97+
auto node =
98+
std::shared_ptr<ArithmeticDagNode<T>>(new ArithmeticDagNode<T>());
99+
node->node_variant.template emplace<MultiplyNode<T>>(
100+
MultiplyNode<T>{std::move(lhs), std::move(rhs)});
101+
return node;
102+
}
103+
104+
static std::shared_ptr<ArithmeticDagNode<T>> power(
105+
std::shared_ptr<ArithmeticDagNode<T>> base, size_t exponent) {
106+
assert(base && "invalid base for power");
107+
auto node =
108+
std::shared_ptr<ArithmeticDagNode<T>>(new ArithmeticDagNode<T>());
109+
node->node_variant.template emplace<PowerNode<T>>(
110+
PowerNode<T>{std::move(base), exponent});
111+
return node;
112+
}
113+
114+
ArithmeticDagNode(const ArithmeticDagNode&) = default;
115+
ArithmeticDagNode& operator=(const ArithmeticDagNode&) = default;
116+
ArithmeticDagNode(ArithmeticDagNode&&) noexcept = default;
117+
ArithmeticDagNode& operator=(ArithmeticDagNode&&) noexcept = default;
118+
119+
// Visitor pattern
120+
template <typename VisitorFunc>
121+
decltype(auto) visit(VisitorFunc&& visitor) {
122+
return std::visit(std::forward<VisitorFunc>(visitor), node_variant);
123+
}
124+
125+
template <typename VisitorFunc>
126+
decltype(auto) visit(VisitorFunc&& visitor) const {
127+
return std::visit(std::forward<VisitorFunc>(visitor), node_variant);
128+
}
129+
};
130+
131+
/// A base class for visitors that caches intermediate results.
132+
///
133+
/// Template parameters:
134+
/// T: The type of the leaf nodes.
135+
/// ResultType: The type of the result of the visit.
136+
template <typename T, typename ResultType>
137+
class CachingVisitor {
138+
public:
139+
virtual ~CachingVisitor() = default;
140+
141+
/// The main entry point that contains the caching logic.
142+
ResultType process(const std::shared_ptr<ArithmeticDagNode<T>>& node) {
143+
assert(node != nullptr && "invalid null node!");
144+
145+
const auto* node_ptr = node.get();
146+
if (auto it = cache.find(node_ptr); it != cache.end()) {
147+
return it->second;
148+
}
149+
150+
ResultType result = std::visit(*this, node->node_variant);
151+
cache[node_ptr] = result;
152+
return result;
153+
}
154+
155+
// --- Virtual Visit Methods ---
156+
// Derived classes must override these for the node types they support.
157+
158+
virtual ResultType operator()(const ConstantNode& node) {
159+
assert(false && "Visit logic for ConstantNode is not implemented.");
160+
}
161+
162+
virtual ResultType operator()(const LeafNode<T>& node) {
163+
assert(false && "Visit logic for LeafNode is not implemented.");
164+
}
165+
166+
virtual ResultType operator()(const AddNode<T>& node) {
167+
assert(false && "Visit logic for AddNode is not implemented.");
168+
}
169+
170+
virtual ResultType operator()(const MultiplyNode<T>& node) {
171+
assert(false && "Visit logic for MultiplyNode is not implemented.");
172+
}
173+
174+
virtual ResultType operator()(const PowerNode<T>& node) {
175+
assert(false && "Visit logic for PowerNode is not implemented.");
176+
}
177+
178+
private:
179+
std::unordered_map<const ArithmeticDagNode<T>*, ResultType> cache;
180+
};
181+
182+
} // namespace heir
183+
} // namespace mlir
184+
185+
#endif // LIB_UTILS_ARITHMETICDAG_H_

lib/Utils/ArithmeticDagTest.cpp

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
#include <cmath>
2+
#include <iomanip>
3+
#include <ios>
4+
#include <sstream>
5+
#include <string>
6+
7+
#include "gtest/gtest.h" // from @googletest
8+
#include "lib/Utils/ArithmeticDag.h"
9+
10+
namespace mlir {
11+
namespace heir {
12+
namespace {
13+
14+
using StringLeavedDag = ArithmeticDagNode<std::string>;
15+
using DoubleLeavedDag = ArithmeticDagNode<double>;
16+
17+
struct FlattenedStringVisitor {
18+
std::string operator()(const ConstantNode& node) const {
19+
std::stringstream ss;
20+
ss << std::fixed << std::setprecision(2) << node.value;
21+
return ss.str();
22+
}
23+
24+
std::string operator()(const LeafNode<std::string>& node) const {
25+
return node.value;
26+
}
27+
28+
std::string operator()(const AddNode<std::string>& node) const {
29+
std::stringstream ss;
30+
ss << "(" << node.left->visit(*this) << " + " << node.right->visit(*this)
31+
<< ")";
32+
return ss.str();
33+
}
34+
35+
std::string operator()(const MultiplyNode<std::string>& node) const {
36+
std::stringstream ss;
37+
ss << node.left->visit(*this) << " * " << node.right->visit(*this);
38+
return ss.str();
39+
}
40+
41+
std::string operator()(const PowerNode<std::string>& node) const {
42+
std::stringstream ss;
43+
ss << "(" << node.base->visit(*this) << " ^ " << node.exponent << ")";
44+
return ss.str();
45+
}
46+
};
47+
48+
class EvalVisitor : public CachingVisitor<double, double> {
49+
public:
50+
EvalVisitor() : CachingVisitor<double, double>(), callCount(0) {}
51+
52+
// To test that caching works as expected.
53+
int callCount;
54+
55+
double operator()(const ConstantNode& node) override {
56+
callCount += 1;
57+
return node.value;
58+
}
59+
60+
double operator()(const LeafNode<double>& node) override {
61+
callCount += 1;
62+
return node.value;
63+
}
64+
65+
double operator()(const AddNode<double>& node) override {
66+
// Recursive calls use the public `process` method from the base class
67+
// to ensure caching is applied at every step.
68+
callCount += 1;
69+
return this->process(node.left) + this->process(node.right);
70+
}
71+
72+
double operator()(const MultiplyNode<double>& node) override {
73+
callCount += 1;
74+
return this->process(node.left) * this->process(node.right);
75+
}
76+
77+
double operator()(const PowerNode<double>& node) override {
78+
callCount += 1;
79+
return std::pow(this->process(node.base), node.exponent);
80+
}
81+
};
82+
83+
TEST(ArithmeticDagTest, TestPrint) {
84+
auto root = StringLeavedDag::mul(
85+
StringLeavedDag::add(StringLeavedDag::leaf("x"),
86+
StringLeavedDag::constant(3.0)),
87+
StringLeavedDag::power(StringLeavedDag::leaf("y"), 2));
88+
89+
FlattenedStringVisitor visitor;
90+
std::string result = root->visit(visitor);
91+
EXPECT_EQ(result, "(x + 3.00) * (y ^ 2)");
92+
}
93+
94+
TEST(ArithmeticDagTest, TestProperDag) {
95+
auto shared = StringLeavedDag::power(StringLeavedDag::leaf("y"), 2);
96+
auto root =
97+
StringLeavedDag::mul(StringLeavedDag::add(shared, shared), shared);
98+
99+
FlattenedStringVisitor visitor;
100+
std::string result = root->visit(visitor);
101+
EXPECT_EQ(result, "((y ^ 2) + (y ^ 2)) * (y ^ 2)");
102+
}
103+
104+
TEST(ArithmeticDagTest, TestEvaluationVisitor) {
105+
auto shared = DoubleLeavedDag::power(DoubleLeavedDag::leaf(2.0), 2);
106+
auto root = DoubleLeavedDag::mul(DoubleLeavedDag::add(shared, shared),
107+
DoubleLeavedDag::constant(3.0));
108+
109+
EvalVisitor visitor;
110+
double result = root->visit(visitor);
111+
EXPECT_EQ(result, 24.0);
112+
EXPECT_EQ(visitor.callCount, 5);
113+
}
114+
115+
} // namespace
116+
} // namespace heir
117+
} // namespace mlir

lib/Utils/BUILD

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,3 +153,18 @@ cc_library(
153153
"@llvm-project//mlir:Support",
154154
],
155155
)
156+
157+
cc_library(
158+
name = "ArithmeticDag",
159+
srcs = ["ArithmeticDag.h"],
160+
hdrs = ["ArithmeticDag.h"],
161+
)
162+
163+
cc_test(
164+
name = "ArithmeticDagTest",
165+
srcs = ["ArithmeticDagTest.cpp"],
166+
deps = [
167+
":ArithmeticDag",
168+
"@googletest//:gtest_main",
169+
],
170+
)

0 commit comments

Comments
 (0)