|
8 | 8 | namespace Halide { |
9 | 9 | namespace Internal { |
10 | 10 |
|
namespace {

// Compile-time probe: is `NodeT::make(make_args...)` a well-formed call?
//
// Two overloads compete. The first is only viable when the call expression in
// its trailing return type compiles (otherwise SFINAE removes it) and yields
// std::true_type. The second is a C-variadic catch-all yielding
// std::false_type; an ellipsis parameter always ranks last in overload
// resolution, so it is chosen only when the first overload was discarded.
//
// Meant to be used in an unevaluated context, e.g.
//   static_assert(decltype(check_make_args<Node>(args...))::value, "...");
template<typename NodeT, typename... MakeArgs>
static constexpr auto check_make_args(MakeArgs &&...make_args)
    -> decltype(NodeT::make(std::forward<MakeArgs>(make_args)...), std::true_type{}) {
    return {};
}

template<typename NodeT, typename... MakeArgs>
static constexpr std::false_type check_make_args(...) {
    return {};
}

}  // namespace
| 30 | + |
| 31 | +template<typename T> |
| 32 | +std::string IRGraphCXXPrinter::to_cpp_arg(const T &x) { |
| 33 | + if constexpr (std::is_arithmetic_v<T>) { |
| 34 | + return std::to_string(x); |
| 35 | + } else { |
| 36 | + internal_error << "Not supported to print"; |
| 37 | + } |
| 38 | +} |
| 39 | + |
| 40 | +template<> |
| 41 | +std::string IRGraphCXXPrinter::to_cpp_arg<Expr>(const Expr &e) { |
| 42 | + if (!e.defined()) { |
| 43 | + return "Expr()"; |
| 44 | + } |
| 45 | + include(e); |
| 46 | + return node_names.at(e.get()); |
| 47 | +} |
| 48 | + |
| 49 | +template<> |
| 50 | +std::string IRGraphCXXPrinter::to_cpp_arg<Stmt>(const Stmt &s) { |
| 51 | + if (!s.defined()) { |
| 52 | + return "Stmt()"; |
| 53 | + } |
| 54 | + include(s); |
| 55 | + return node_names.at(s.get()); |
| 56 | +} |
| 57 | + |
| 58 | +template<> |
| 59 | +std::string IRGraphCXXPrinter::to_cpp_arg<Range>(const Range &r) { |
| 60 | + include(r.min); |
| 61 | + include(r.extent); |
| 62 | + return "Range(" + node_names.at(r.min.get()) + ", " + node_names.at(r.extent.get()) + ")"; |
| 63 | +} |
| 64 | + |
| 65 | +template<> |
| 66 | +std::string IRGraphCXXPrinter::to_cpp_arg<std::string>(const std::string &s) { |
| 67 | + return "\"" + s + "\""; |
| 68 | +} |
| 69 | +template<> |
| 70 | +std::string IRGraphCXXPrinter::to_cpp_arg<ForType>(const ForType &f) { |
| 71 | + switch (f) { |
| 72 | + case ForType::Serial: |
| 73 | + return "ForType::Serial"; |
| 74 | + case ForType::Parallel: |
| 75 | + return "ForType::Parallel"; |
| 76 | + case ForType::Vectorized: |
| 77 | + return "ForType::Vectorized"; |
| 78 | + case ForType::Unrolled: |
| 79 | + return "ForType::Unrolled"; |
| 80 | + case ForType::Extern: |
| 81 | + return "ForType::Extern"; |
| 82 | + case ForType::GPUBlock: |
| 83 | + return "ForType::GPUBlock"; |
| 84 | + case ForType::GPUThread: |
| 85 | + return "ForType::GPUThread"; |
| 86 | + case ForType::GPULane: |
| 87 | + return "ForType::GPULane"; |
| 88 | + default: |
| 89 | + return "ForType::Serial"; |
| 90 | + } |
| 91 | +} |
| 92 | + |
| 93 | +template<> |
| 94 | +std::string IRGraphCXXPrinter::to_cpp_arg<VectorReduce::Operator>(const VectorReduce::Operator &op) { |
| 95 | + switch (op) { |
| 96 | + case VectorReduce::Add: |
| 97 | + return "VectorReduce::Add"; |
| 98 | + case VectorReduce::SaturatingAdd: |
| 99 | + return "VectorReduce::SaturatingAdd"; |
| 100 | + case VectorReduce::Mul: |
| 101 | + return "VectorReduce::Mul"; |
| 102 | + case VectorReduce::Min: |
| 103 | + return "VectorReduce::Min"; |
| 104 | + case VectorReduce::Max: |
| 105 | + return "VectorReduce::Max"; |
| 106 | + case VectorReduce::And: |
| 107 | + return "VectorReduce::And"; |
| 108 | + case VectorReduce::Or: |
| 109 | + return "VectorReduce::Or"; |
| 110 | + } |
| 111 | + internal_error << "Invalid VectorReduce"; |
| 112 | +} |
| 113 | + |
| 114 | +template<> |
| 115 | +std::string IRGraphCXXPrinter::to_cpp_arg<Type>(const Type &t) { |
| 116 | + std::ostringstream oss; |
| 117 | + oss << "Type(Type::" |
| 118 | + << (t.is_int() ? "Int" : t.is_uint() ? "UInt" : |
| 119 | + t.is_float() ? "Float" : |
| 120 | + t.is_bfloat() ? "BFloat" : |
| 121 | + "Handle") |
| 122 | + << ", " << t.bits() << ", " << t.lanes() << ")"; |
| 123 | + return oss.str(); |
| 124 | +} |
| 125 | + |
| 126 | +template<> |
| 127 | +std::string IRGraphCXXPrinter::to_cpp_arg<ModulusRemainder>(const ModulusRemainder &align) { |
| 128 | + return "ModulusRemainder(" + std::to_string(align.modulus) + ", " + std::to_string(align.remainder) + ")"; |
| 129 | +} |
| 130 | + |
| 131 | +template<typename T> |
| 132 | +std::string IRGraphCXXPrinter::to_cpp_arg(const std::vector<T> &vec) { |
| 133 | + std::string res = "{"; |
| 134 | + for (size_t i = 0; i < vec.size(); ++i) { |
| 135 | + res += to_cpp_arg(vec[i]); |
| 136 | + if (i + 1 < vec.size()) { |
| 137 | + res += ", "; |
| 138 | + } |
| 139 | + } |
| 140 | + res += "}"; |
| 141 | + return res; |
| 142 | +} |
| 143 | + |
| 144 | +template<typename T, typename... Args> |
| 145 | +void IRGraphCXXPrinter::emit_node(const char *node_type_str, const T *op, Args &&...args) { |
| 146 | + if (node_names.count(op)) { |
| 147 | + return; |
| 148 | + } |
| 149 | + |
| 150 | + static_assert(decltype(check_make_args<T>(std::forward<Args>(args)...))::value, |
| 151 | + "Arguments extracted for printer do not match any T::make() signature! " |
| 152 | + "Check your VISIT_NODE macro arguments."); |
| 153 | + |
| 154 | + // Evaluate arguments post-order to build dependencies. |
| 155 | + // (C++11 guarantees left-to-right evaluation in brace-init lists) |
| 156 | + std::vector<std::string> printed_args = {to_cpp_arg(args)...}; |
| 157 | + |
| 158 | + // Generate the actual C++ code |
| 159 | + bool is_expr = std::is_base_of_v<BaseExprNode, T>; |
| 160 | + std::string var_name = (is_expr ? "expr_" : "stmt_") + std::to_string(var_counter++); |
| 161 | + |
| 162 | + os << (is_expr ? "Expr " : "Stmt ") << var_name << " = " << node_type_str << "::make("; |
| 163 | + for (size_t i = 0; i < printed_args.size(); ++i) { |
| 164 | + os << printed_args[i] << (i + 1 == printed_args.size() ? "" : ", "); |
| 165 | + } |
| 166 | + os << ");\n"; |
| 167 | + |
| 168 | + node_names[op] = var_name; |
| 169 | +} |
| 170 | + |
| 171 | +// Macro handles mapping the IR node pointer to our clever template. |
| 172 | +#define VISIT_NODE(NodeType, ...) \ |
| 173 | + void IRGraphCXXPrinter::visit(const NodeType *op) { \ |
| 174 | + IRGraphVisitor::visit(op); \ |
| 175 | + emit_node<NodeType>(#NodeType, op, __VA_ARGS__); \ |
| 176 | + } |
| 177 | + |
| 178 | +// --- 1. Core / Primitive Values --- |
| 179 | +VISIT_NODE(IntImm, op->type, op->value) |
| 180 | +VISIT_NODE(UIntImm, op->type, op->value) |
| 181 | +VISIT_NODE(FloatImm, op->type, op->value) |
| 182 | +VISIT_NODE(StringImm, op->value) |
| 183 | + |
| 184 | +// --- 2. Variable & Broadcast --- |
| 185 | +VISIT_NODE(Variable, op->type, op->name /*, op->image, op->param, op->reduction_domain */) |
| 186 | +VISIT_NODE(Broadcast, op->value, op->lanes) |
| 187 | + |
| 188 | +// --- 3. Binary & Unary Math Nodes --- |
| 189 | +VISIT_NODE(Add, op->a, op->b) |
| 190 | +VISIT_NODE(Sub, op->a, op->b) |
| 191 | +VISIT_NODE(Mod, op->a, op->b) |
| 192 | +VISIT_NODE(Mul, op->a, op->b) |
| 193 | +VISIT_NODE(Div, op->a, op->b) |
| 194 | +VISIT_NODE(Min, op->a, op->b) |
| 195 | +VISIT_NODE(Max, op->a, op->b) |
| 196 | +VISIT_NODE(EQ, op->a, op->b) |
| 197 | +VISIT_NODE(NE, op->a, op->b) |
| 198 | +VISIT_NODE(LT, op->a, op->b) |
| 199 | +VISIT_NODE(LE, op->a, op->b) |
| 200 | +VISIT_NODE(GT, op->a, op->b) |
| 201 | +VISIT_NODE(GE, op->a, op->b) |
| 202 | +VISIT_NODE(And, op->a, op->b) |
| 203 | +VISIT_NODE(Or, op->a, op->b) |
| 204 | +VISIT_NODE(Not, op->a) |
| 205 | + |
| 206 | +// --- 4. Casts & Shuffles --- |
| 207 | +VISIT_NODE(Cast, op->type, op->value) |
| 208 | +VISIT_NODE(Reinterpret, op->type, op->value) |
| 209 | +VISIT_NODE(Shuffle, op->vectors, op->indices) |
| 210 | + |
| 211 | +// --- 5. Complex Expressions --- |
| 212 | +VISIT_NODE(Select, op->condition, op->true_value, op->false_value) |
| 213 | +VISIT_NODE(Load, op->type, op->name, op->index, op->image, op->param, op->predicate, op->alignment) |
| 214 | +VISIT_NODE(Ramp, op->base, op->stride, op->lanes) |
| 215 | +VISIT_NODE(Call, op->type, op->name, op->args, op->call_type, op->func, op->value_index, op->image, op->param) |
| 216 | +VISIT_NODE(Let, op->name, op->value, op->body) |
| 217 | +VISIT_NODE(VectorReduce, op->op, op->value, op->type.lanes()) |
| 218 | + |
| 219 | +// --- 6. Core Statements --- |
| 220 | +VISIT_NODE(LetStmt, op->name, op->value, op->body) |
| 221 | +VISIT_NODE(AssertStmt, op->condition, op->message) |
| 222 | +VISIT_NODE(Evaluate, op->value) |
| 223 | +VISIT_NODE(Block, op->first, op->rest) |
| 224 | +VISIT_NODE(IfThenElse, op->condition, op->then_case, op->else_case) |
| 225 | +VISIT_NODE(For, op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, op->body) |
| 226 | + |
| 227 | +// --- 7. Memory / Buffer Operations --- |
| 228 | +VISIT_NODE(Store, op->name, op->value, op->index, op->param, op->predicate, op->alignment) |
| 229 | +VISIT_NODE(Provide, op->name, op->values, op->args, op->predicate) |
| 230 | +VISIT_NODE(Allocate, op->name, op->type, op->memory_type, op->extents, op->condition, op->body, op->new_expr, op->free_function) |
| 231 | +VISIT_NODE(Free, op->name) |
| 232 | +VISIT_NODE(Realize, op->name, op->types, op->memory_type, op->bounds, op->condition, op->body) |
| 233 | +VISIT_NODE(Prefetch, op->name, op->types, op->bounds, op->prefetch, op->condition, op->body) |
| 234 | +VISIT_NODE(HoistedStorage, op->name, op->body) |
| 235 | + |
| 236 | +// --- 8. Concurrency & Sync --- |
| 237 | +VISIT_NODE(ProducerConsumer, op->name, op->is_producer, op->body) |
| 238 | +VISIT_NODE(Acquire, op->semaphore, op->count, op->body) |
| 239 | +VISIT_NODE(Fork, op->first, op->rest) |
| 240 | +VISIT_NODE(Atomic, op->producer_name, op->mutex_name, op->body) |
| 241 | + |
11 | 242 | void IRGraphCXXPrinter::test() { |
12 | 243 | // This: |
13 | 244 | Expr e = Select::make(Mod::make(Ramp::make(10, 314, 8), Broadcast::make(10, 8)) < Variable::make(Int(32), "p"), Broadcast::make(40, 8) + Ramp::make(4, 8, 8), VectorReduce::make(VectorReduce::Add, Ramp::make(0, 1, 16), 8)); |
@@ -41,9 +272,7 @@ void IRGraphCXXPrinter::test() { |
41 | 272 | // Now let's see if it matches: |
42 | 273 | internal_assert(equal(expr_19, e)) << "Expressions don't match:\n\n" |
43 | 274 | << e << "\n\n" |
44 | | - << expr_19 << "\n"; |
45 | | - |
46 | | - // Here is a bad typo for Alex who likes progamming. Above is a badly intented line. Two typos? |
| 275 | + << expr_19 << "\n"; |
47 | 276 | } |
48 | 277 | } // namespace Internal |
49 | 278 | } // namespace Halide |
0 commit comments