Skip to content

Commit 80d79f3

Browse files
j2kuncopybara-github
authored andcommitted
Add a data structure for a DAG of arithmetic operations
The basic structure just represents a (leaf-type-agnostic) DAG of operations and provides a mechanism to create a visitor using std::visit (with a base class for a caching visitor). This is needed for lower_eval to separate the construction of the lowered arithmetic tree from the materialization of the IR. However, I think it will (with adaptations) be useful for other situations as well: - Symbolic noise analysis in #1817 - To simplify the core routine in operation-balancer (https://github.com/google/heir/blob/b0cf72da113e6c7282733f8ba6bfcb7754a7495c/lib/Transforms/OperationBalancer/OperationBalancer.cpp#L74) PiperOrigin-RevId: 770267746
1 parent 0fe389a commit 80d79f3

File tree

3 files changed

+316
-0
lines changed

3 files changed

+316
-0
lines changed

lib/Utils/ArithmeticDag.h

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
#ifndef LIB_UTILS_ARITHMETICDAG_H_
2+
#define LIB_UTILS_ARITHMETICDAG_H_
3+
4+
#include <cassert>
5+
#include <cstddef>
6+
#include <memory>
7+
#include <unordered_map>
8+
#include <utility>
9+
#include <variant>
10+
11+
namespace mlir {
12+
namespace heir {
13+
14+
// This file contains a generic DAG structure that can be used for representing
15+
// arithmetic DAGs with leaf nodes of various types.
16+
template <typename T>
17+
struct ArithmeticDagNode;
18+
19+
// A leaf node for the DAG
20+
template <typename T>
21+
struct LeafNode {
22+
T value;
23+
};
24+
25+
struct ConstantNode {
26+
double value;
27+
};
28+
29+
template <typename T>
30+
struct AddNode {
31+
std::shared_ptr<ArithmeticDagNode<T>> left;
32+
std::shared_ptr<ArithmeticDagNode<T>> right;
33+
};
34+
35+
template <typename T>
36+
struct MultiplyNode {
37+
std::shared_ptr<ArithmeticDagNode<T>> left;
38+
std::shared_ptr<ArithmeticDagNode<T>> right;
39+
};
40+
41+
template <typename T>
42+
struct PowerNode {
43+
std::shared_ptr<ArithmeticDagNode<T>> base;
44+
size_t exponent;
45+
};
46+
47+
template <typename T>
48+
struct ArithmeticDagNode {
49+
public:
50+
std::variant<ConstantNode, LeafNode<T>, AddNode<T>, MultiplyNode<T>,
51+
PowerNode<T>>
52+
node_variant;
53+
54+
explicit ArithmeticDagNode(const ConstantNode& node) : node_variant(node) {}
55+
explicit ArithmeticDagNode(ConstantNode&& node)
56+
: node_variant(std::move(node)) {}
57+
explicit ArithmeticDagNode(const T& value)
58+
: node_variant(LeafNode<T>{value}) {}
59+
explicit ArithmeticDagNode(T&& value)
60+
: node_variant(LeafNode<T>{std::move(value)}) {}
61+
62+
private:
63+
ArithmeticDagNode() = default;
64+
65+
public:
66+
// Static factory methods
67+
static std::shared_ptr<ArithmeticDagNode<T>> constant(double constant) {
68+
// Note: using std::shared_ptr(new...) because the constructor is private.
69+
// std::make_shared would require a public constructor.
70+
return std::shared_ptr<ArithmeticDagNode<T>>(
71+
new ArithmeticDagNode<T>(ConstantNode(constant)));
72+
}
73+
74+
static std::shared_ptr<ArithmeticDagNode<T>> leaf(const T& value) {
75+
return std::shared_ptr<ArithmeticDagNode<T>>(
76+
new ArithmeticDagNode<T>(value));
77+
}
78+
79+
static std::shared_ptr<ArithmeticDagNode<T>> add(
80+
std::shared_ptr<ArithmeticDagNode<T>> lhs,
81+
std::shared_ptr<ArithmeticDagNode<T>> rhs) {
82+
assert(lhs && rhs && "invalid add");
83+
auto node =
84+
std::shared_ptr<ArithmeticDagNode<T>>(new ArithmeticDagNode<T>());
85+
// Note, to satisfy variant we need to use aggregate initialization inside
86+
// emplace
87+
node->node_variant.template emplace<AddNode<T>>(
88+
AddNode<T>{std::move(lhs), std::move(rhs)});
89+
return node;
90+
}
91+
92+
static std::shared_ptr<ArithmeticDagNode<T>> mul(
93+
std::shared_ptr<ArithmeticDagNode<T>> lhs,
94+
std::shared_ptr<ArithmeticDagNode<T>> rhs) {
95+
assert(lhs && rhs && "invalid mul");
96+
auto node =
97+
std::shared_ptr<ArithmeticDagNode<T>>(new ArithmeticDagNode<T>());
98+
node->node_variant.template emplace<MultiplyNode<T>>(
99+
MultiplyNode<T>{std::move(lhs), std::move(rhs)});
100+
return node;
101+
}
102+
103+
static std::shared_ptr<ArithmeticDagNode<T>> power(
104+
std::shared_ptr<ArithmeticDagNode<T>> base, size_t exponent) {
105+
assert(base && "invalid base for power");
106+
auto node =
107+
std::shared_ptr<ArithmeticDagNode<T>>(new ArithmeticDagNode<T>());
108+
node->node_variant.template emplace<PowerNode<T>>(
109+
PowerNode<T>{std::move(base), exponent});
110+
return node;
111+
}
112+
113+
ArithmeticDagNode(const ArithmeticDagNode&) = default;
114+
ArithmeticDagNode& operator=(const ArithmeticDagNode&) = default;
115+
ArithmeticDagNode(ArithmeticDagNode&&) noexcept = default;
116+
ArithmeticDagNode& operator=(ArithmeticDagNode&&) noexcept = default;
117+
118+
// Visitor pattern
119+
template <typename VisitorFunc>
120+
decltype(auto) visit(VisitorFunc&& visitor) {
121+
return std::visit(std::forward<VisitorFunc>(visitor), node_variant);
122+
}
123+
124+
template <typename VisitorFunc>
125+
decltype(auto) visit(VisitorFunc&& visitor) const {
126+
return std::visit(std::forward<VisitorFunc>(visitor), node_variant);
127+
}
128+
};
129+
130+
/// A base class for visitors that caches intermediate results.
131+
///
132+
/// Template parameters:
133+
/// T: The type of the leaf nodes.
134+
/// ResultType: The type of the result of the visit.
135+
template <typename T, typename ResultType>
136+
class CachingVisitor {
137+
public:
138+
virtual ~CachingVisitor() = default;
139+
140+
/// The main entry point that contains the caching logic.
141+
ResultType process(const std::shared_ptr<ArithmeticDagNode<T>>& node) {
142+
assert(node != nullptr && "invalid null node!");
143+
144+
const auto* node_ptr = node.get();
145+
if (auto it = cache.find(node_ptr); it != cache.end()) {
146+
return it->second;
147+
}
148+
149+
ResultType result = std::visit(*this, node->node_variant);
150+
cache[node_ptr] = result;
151+
return result;
152+
}
153+
154+
// --- Virtual Visit Methods ---
155+
// Derived classes must override these for the node types they support.
156+
157+
virtual ResultType operator()(const ConstantNode& node) {
158+
assert(false && "Visit logic for ConstantNode is not implemented.");
159+
}
160+
161+
virtual ResultType operator()(const LeafNode<T>& node) {
162+
assert(false && "Visit logic for LeafNode is not implemented.");
163+
}
164+
165+
virtual ResultType operator()(const AddNode<T>& node) {
166+
assert(false && "Visit logic for AddNode is not implemented.");
167+
}
168+
169+
virtual ResultType operator()(const MultiplyNode<T>& node) {
170+
assert(false && "Visit logic for MultiplyNode is not implemented.");
171+
}
172+
173+
virtual ResultType operator()(const PowerNode<T>& node) {
174+
assert(false && "Visit logic for PowerNode is not implemented.");
175+
}
176+
177+
private:
178+
std::unordered_map<const ArithmeticDagNode<T>*, ResultType> cache;
179+
};
180+
181+
} // namespace heir
182+
} // namespace mlir
183+
184+
#endif // LIB_UTILS_ARITHMETICDAG_H_

lib/Utils/ArithmeticDagTest.cpp

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
#include <cmath>
2+
#include <iomanip>
3+
#include <ios>
4+
#include <sstream>
5+
#include <string>
6+
7+
#include "gtest/gtest.h" // from @googletest
8+
#include "lib/Utils/ArithmeticDag.h"
9+
10+
namespace mlir {
11+
namespace heir {
12+
namespace {
13+
14+
using StringLeavedDag = ArithmeticDagNode<std::string>;
15+
using DoubleLeavedDag = ArithmeticDagNode<double>;
16+
17+
struct FlattenedStringVisitor {
18+
std::string operator()(const ConstantNode& node) const {
19+
std::stringstream ss;
20+
ss << std::fixed << std::setprecision(2) << node.value;
21+
return ss.str();
22+
}
23+
24+
std::string operator()(const LeafNode<std::string>& node) const {
25+
return node.value;
26+
}
27+
28+
std::string operator()(const AddNode<std::string>& node) const {
29+
std::stringstream ss;
30+
ss << "(" << node.left->visit(*this) << " + " << node.right->visit(*this)
31+
<< ")";
32+
return ss.str();
33+
}
34+
35+
std::string operator()(const MultiplyNode<std::string>& node) const {
36+
std::stringstream ss;
37+
ss << node.left->visit(*this) << " * " << node.right->visit(*this);
38+
return ss.str();
39+
}
40+
41+
std::string operator()(const PowerNode<std::string>& node) const {
42+
std::stringstream ss;
43+
ss << "(" << node.base->visit(*this) << " ^ " << node.exponent << ")";
44+
return ss.str();
45+
}
46+
};
47+
48+
class EvalVisitor : public CachingVisitor<double, double> {
49+
public:
50+
EvalVisitor() : CachingVisitor<double, double>(), callCount(0) {}
51+
52+
// To test that caching works as expected.
53+
int callCount;
54+
55+
double operator()(const ConstantNode& node) override {
56+
callCount += 1;
57+
return node.value;
58+
}
59+
60+
double operator()(const LeafNode<double>& node) override {
61+
callCount += 1;
62+
return node.value;
63+
}
64+
65+
double operator()(const AddNode<double>& node) override {
66+
// Recursive calls use the public `process` method from the base class
67+
// to ensure caching is applied at every step.
68+
callCount += 1;
69+
return this->process(node.left) + this->process(node.right);
70+
}
71+
72+
double operator()(const MultiplyNode<double>& node) override {
73+
callCount += 1;
74+
return this->process(node.left) * this->process(node.right);
75+
}
76+
77+
double operator()(const PowerNode<double>& node) override {
78+
callCount += 1;
79+
return std::pow(this->process(node.base), node.exponent);
80+
}
81+
};
82+
83+
TEST(ArithmeticDagTest, TestPrint) {
84+
auto root = StringLeavedDag::mul(
85+
StringLeavedDag::add(StringLeavedDag::leaf("x"),
86+
StringLeavedDag::constant(3.0)),
87+
StringLeavedDag::power(StringLeavedDag::leaf("y"), 2));
88+
89+
FlattenedStringVisitor visitor;
90+
std::string result = root->visit(visitor);
91+
EXPECT_EQ(result, "(x + 3.00) * (y ^ 2)");
92+
}
93+
94+
TEST(ArithmeticDagTest, TestProperDag) {
95+
auto shared = StringLeavedDag::power(StringLeavedDag::leaf("y"), 2);
96+
auto root =
97+
StringLeavedDag::mul(StringLeavedDag::add(shared, shared), shared);
98+
99+
FlattenedStringVisitor visitor;
100+
std::string result = root->visit(visitor);
101+
EXPECT_EQ(result, "((y ^ 2) + (y ^ 2)) * (y ^ 2)");
102+
}
103+
104+
TEST(ArithmeticDagTest, TestEvaluationVisitor) {
105+
auto shared = DoubleLeavedDag::power(DoubleLeavedDag::leaf(2.0), 2);
106+
auto root = DoubleLeavedDag::mul(DoubleLeavedDag::add(shared, shared),
107+
DoubleLeavedDag::constant(3.0));
108+
109+
EvalVisitor visitor;
110+
double result = root->visit(visitor);
111+
EXPECT_EQ(result, 24.0);
112+
EXPECT_EQ(visitor.callCount, 5);
113+
}
114+
115+
} // namespace
116+
} // namespace heir
117+
} // namespace mlir

lib/Utils/BUILD

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,3 +153,18 @@ cc_library(
153153
"@llvm-project//mlir:Support",
154154
],
155155
)
156+
157+
cc_library(
158+
name = "ArithmeticDag",
159+
srcs = ["ArithmeticDag.h"],
160+
hdrs = ["ArithmeticDag.h"],
161+
)
162+
163+
cc_test(
164+
name = "ArithmeticDagTest",
165+
srcs = ["ArithmeticDagTest.cpp"],
166+
deps = [
167+
":ArithmeticDag",
168+
"@googletest//:gtest_main",
169+
],
170+
)

0 commit comments

Comments
 (0)