Skip to content

Commit 9b428da

Browse files
committed
Moved implementation of the IRGraphCXXPrinter to the cpp file.
1 parent a40180b commit 9b428da

File tree

2 files changed

+284
-244
lines changed

2 files changed

+284
-244
lines changed

src/IRGraphCXXPrinter.cpp

Lines changed: 232 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,237 @@
88
namespace Halide {
99
namespace Internal {
1010

11+
namespace {
12+
// =========================================================================
13+
// ✨ CLEVER TEMPLATING ✨
14+
// This SFINAE trick checks if `T::make` can be invoked with `Args...`.
15+
// It will trigger a static_assert if you forget an argument or pass the
16+
// wrong field types, completely preventing generated code compile errors!
17+
// =========================================================================
18+
template<typename T, typename... Args>
19+
static constexpr auto check_make_args(Args &&...args)
20+
-> decltype(T::make(std::forward<Args>(args)...), std::true_type{}) {
21+
return std::true_type{};
22+
}
23+
24+
template<typename T, typename... Args>
25+
static constexpr std::false_type check_make_args(...) {
26+
return std::false_type{};
27+
}
28+
29+
} // namespace
30+
31+
template<typename T>
32+
std::string IRGraphCXXPrinter::to_cpp_arg(const T &x) {
33+
if constexpr (std::is_arithmetic_v<T>) {
34+
return std::to_string(x);
35+
} else {
36+
internal_error << "Not supported to print";
37+
}
38+
}
39+
40+
template<>
41+
std::string IRGraphCXXPrinter::to_cpp_arg<Expr>(const Expr &e) {
42+
if (!e.defined()) {
43+
return "Expr()";
44+
}
45+
include(e);
46+
return node_names.at(e.get());
47+
}
48+
49+
template<>
50+
std::string IRGraphCXXPrinter::to_cpp_arg<Stmt>(const Stmt &s) {
51+
if (!s.defined()) {
52+
return "Stmt()";
53+
}
54+
include(s);
55+
return node_names.at(s.get());
56+
}
57+
58+
template<>
59+
std::string IRGraphCXXPrinter::to_cpp_arg<Range>(const Range &r) {
60+
include(r.min);
61+
include(r.extent);
62+
return "Range(" + node_names.at(r.min.get()) + ", " + node_names.at(r.extent.get()) + ")";
63+
}
64+
65+
template<>
66+
std::string IRGraphCXXPrinter::to_cpp_arg<std::string>(const std::string &s) {
67+
return "\"" + s + "\"";
68+
}
69+
template<>
70+
std::string IRGraphCXXPrinter::to_cpp_arg<ForType>(const ForType &f) {
71+
switch (f) {
72+
case ForType::Serial:
73+
return "ForType::Serial";
74+
case ForType::Parallel:
75+
return "ForType::Parallel";
76+
case ForType::Vectorized:
77+
return "ForType::Vectorized";
78+
case ForType::Unrolled:
79+
return "ForType::Unrolled";
80+
case ForType::Extern:
81+
return "ForType::Extern";
82+
case ForType::GPUBlock:
83+
return "ForType::GPUBlock";
84+
case ForType::GPUThread:
85+
return "ForType::GPUThread";
86+
case ForType::GPULane:
87+
return "ForType::GPULane";
88+
default:
89+
return "ForType::Serial";
90+
}
91+
}
92+
93+
template<>
94+
std::string IRGraphCXXPrinter::to_cpp_arg<VectorReduce::Operator>(const VectorReduce::Operator &op) {
95+
switch (op) {
96+
case VectorReduce::Add:
97+
return "VectorReduce::Add";
98+
case VectorReduce::SaturatingAdd:
99+
return "VectorReduce::SaturatingAdd";
100+
case VectorReduce::Mul:
101+
return "VectorReduce::Mul";
102+
case VectorReduce::Min:
103+
return "VectorReduce::Min";
104+
case VectorReduce::Max:
105+
return "VectorReduce::Max";
106+
case VectorReduce::And:
107+
return "VectorReduce::And";
108+
case VectorReduce::Or:
109+
return "VectorReduce::Or";
110+
}
111+
internal_error << "Invalid VectorReduce";
112+
}
113+
114+
template<>
115+
std::string IRGraphCXXPrinter::to_cpp_arg<Type>(const Type &t) {
116+
std::ostringstream oss;
117+
oss << "Type(Type::"
118+
<< (t.is_int() ? "Int" : t.is_uint() ? "UInt" :
119+
t.is_float() ? "Float" :
120+
t.is_bfloat() ? "BFloat" :
121+
"Handle")
122+
<< ", " << t.bits() << ", " << t.lanes() << ")";
123+
return oss.str();
124+
}
125+
126+
template<>
127+
std::string IRGraphCXXPrinter::to_cpp_arg<ModulusRemainder>(const ModulusRemainder &align) {
128+
return "ModulusRemainder(" + std::to_string(align.modulus) + ", " + std::to_string(align.remainder) + ")";
129+
}
130+
131+
template<typename T>
132+
std::string IRGraphCXXPrinter::to_cpp_arg(const std::vector<T> &vec) {
133+
std::string res = "{";
134+
for (size_t i = 0; i < vec.size(); ++i) {
135+
res += to_cpp_arg(vec[i]);
136+
if (i + 1 < vec.size()) {
137+
res += ", ";
138+
}
139+
}
140+
res += "}";
141+
return res;
142+
}
143+
144+
template<typename T, typename... Args>
145+
void IRGraphCXXPrinter::emit_node(const char *node_type_str, const T *op, Args &&...args) {
146+
if (node_names.count(op)) {
147+
return;
148+
}
149+
150+
static_assert(decltype(check_make_args<T>(std::forward<Args>(args)...))::value,
151+
"Arguments extracted for printer do not match any T::make() signature! "
152+
"Check your VISIT_NODE macro arguments.");
153+
154+
// Evaluate arguments post-order to build dependencies.
155+
// (C++11 guarantees left-to-right evaluation in brace-init lists)
156+
std::vector<std::string> printed_args = {to_cpp_arg(args)...};
157+
158+
// Generate the actual C++ code
159+
bool is_expr = std::is_base_of_v<BaseExprNode, T>;
160+
std::string var_name = (is_expr ? "expr_" : "stmt_") + std::to_string(var_counter++);
161+
162+
os << (is_expr ? "Expr " : "Stmt ") << var_name << " = " << node_type_str << "::make(";
163+
for (size_t i = 0; i < printed_args.size(); ++i) {
164+
os << printed_args[i] << (i + 1 == printed_args.size() ? "" : ", ");
165+
}
166+
os << ");\n";
167+
168+
node_names[op] = var_name;
169+
}
170+
171+
// Macro handles mapping the IR node pointer to our clever template.
172+
#define VISIT_NODE(NodeType, ...) \
173+
void IRGraphCXXPrinter::visit(const NodeType *op) { \
174+
IRGraphVisitor::visit(op); \
175+
emit_node<NodeType>(#NodeType, op, __VA_ARGS__); \
176+
}
177+
178+
// --- 1. Core / Primitive Values ---
179+
VISIT_NODE(IntImm, op->type, op->value)
180+
VISIT_NODE(UIntImm, op->type, op->value)
181+
VISIT_NODE(FloatImm, op->type, op->value)
182+
VISIT_NODE(StringImm, op->value)
183+
184+
// --- 2. Variable & Broadcast ---
185+
VISIT_NODE(Variable, op->type, op->name /*, op->image, op->param, op->reduction_domain */)
186+
VISIT_NODE(Broadcast, op->value, op->lanes)
187+
188+
// --- 3. Binary & Unary Math Nodes ---
189+
VISIT_NODE(Add, op->a, op->b)
190+
VISIT_NODE(Sub, op->a, op->b)
191+
VISIT_NODE(Mod, op->a, op->b)
192+
VISIT_NODE(Mul, op->a, op->b)
193+
VISIT_NODE(Div, op->a, op->b)
194+
VISIT_NODE(Min, op->a, op->b)
195+
VISIT_NODE(Max, op->a, op->b)
196+
VISIT_NODE(EQ, op->a, op->b)
197+
VISIT_NODE(NE, op->a, op->b)
198+
VISIT_NODE(LT, op->a, op->b)
199+
VISIT_NODE(LE, op->a, op->b)
200+
VISIT_NODE(GT, op->a, op->b)
201+
VISIT_NODE(GE, op->a, op->b)
202+
VISIT_NODE(And, op->a, op->b)
203+
VISIT_NODE(Or, op->a, op->b)
204+
VISIT_NODE(Not, op->a)
205+
206+
// --- 4. Casts & Shuffles ---
207+
VISIT_NODE(Cast, op->type, op->value)
208+
VISIT_NODE(Reinterpret, op->type, op->value)
209+
VISIT_NODE(Shuffle, op->vectors, op->indices)
210+
211+
// --- 5. Complex Expressions ---
212+
VISIT_NODE(Select, op->condition, op->true_value, op->false_value)
213+
VISIT_NODE(Load, op->type, op->name, op->index, op->image, op->param, op->predicate, op->alignment)
214+
VISIT_NODE(Ramp, op->base, op->stride, op->lanes)
215+
VISIT_NODE(Call, op->type, op->name, op->args, op->call_type, op->func, op->value_index, op->image, op->param)
216+
VISIT_NODE(Let, op->name, op->value, op->body)
217+
VISIT_NODE(VectorReduce, op->op, op->value, op->type.lanes())
218+
219+
// --- 6. Core Statements ---
220+
VISIT_NODE(LetStmt, op->name, op->value, op->body)
221+
VISIT_NODE(AssertStmt, op->condition, op->message)
222+
VISIT_NODE(Evaluate, op->value)
223+
VISIT_NODE(Block, op->first, op->rest)
224+
VISIT_NODE(IfThenElse, op->condition, op->then_case, op->else_case)
225+
VISIT_NODE(For, op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, op->body)
226+
227+
// --- 7. Memory / Buffer Operations ---
228+
VISIT_NODE(Store, op->name, op->value, op->index, op->param, op->predicate, op->alignment)
229+
VISIT_NODE(Provide, op->name, op->values, op->args, op->predicate)
230+
VISIT_NODE(Allocate, op->name, op->type, op->memory_type, op->extents, op->condition, op->body, op->new_expr, op->free_function)
231+
VISIT_NODE(Free, op->name)
232+
VISIT_NODE(Realize, op->name, op->types, op->memory_type, op->bounds, op->condition, op->body)
233+
VISIT_NODE(Prefetch, op->name, op->types, op->bounds, op->prefetch, op->condition, op->body)
234+
VISIT_NODE(HoistedStorage, op->name, op->body)
235+
236+
// --- 8. Concurrency & Sync ---
237+
VISIT_NODE(ProducerConsumer, op->name, op->is_producer, op->body)
238+
VISIT_NODE(Acquire, op->semaphore, op->count, op->body)
239+
VISIT_NODE(Fork, op->first, op->rest)
240+
VISIT_NODE(Atomic, op->producer_name, op->mutex_name, op->body)
241+
11242
void IRGraphCXXPrinter::test() {
12243
// This:
13244
Expr e = Select::make(Mod::make(Ramp::make(10, 314, 8), Broadcast::make(10, 8)) < Variable::make(Int(32), "p"), Broadcast::make(40, 8) + Ramp::make(4, 8, 8), VectorReduce::make(VectorReduce::Add, Ramp::make(0, 1, 16), 8));
@@ -41,9 +272,7 @@ void IRGraphCXXPrinter::test() {
41272
// Now let's see if it matches:
42273
internal_assert(equal(expr_19, e)) << "Expressions don't match:\n\n"
43274
<< e << "\n\n"
44-
<< expr_19 << "\n";
45-
46-
// Here is a bad typo for Alex who likes progamming. Above is a badly intented line. Two typos?
275+
<< expr_19 << "\n";
47276
}
48277
} // namespace Internal
49278
} // namespace Halide

0 commit comments

Comments
 (0)