@@ -84,21 +84,25 @@ class LiftExceedingVectors : public IRMutator {
8484 vector<pair<string, Expr>> lets;
8585 bool just_in_let_definition{false };
8686
87- Expr visit (const Let *op) override {
88- internal_error << " We don't want to process Lets. They should have all been converted to LetStmts." ;
89- return IRMutator::visit (op);
90- }
91-
92- Stmt visit (const LetStmt *op) override {
87+ template <typename LetOrLetStmt>
88+ auto visit_let_or_letstmt (const LetOrLetStmt *op) -> decltype(op->body) {
9389 just_in_let_definition = true ;
9490 Expr def = mutate (op->value );
9591 just_in_let_definition = false ;
9692
97- Stmt body = mutate (op->body );
93+ decltype (op-> body ) body = mutate (op->body );
9894 if (def.same_as (op->value ) && body.same_as (op->body )) {
9995 return op;
10096 }
101- return LetStmt::make (op->name , std::move (def), std::move (body));
97+ return LetOrLetStmt::make (op->name , std::move (def), std::move (body));
98+ }
99+
100+ Expr visit (const Let *op) override {
101+ return visit_let_or_letstmt (op);
102+ }
103+
104+ Stmt visit (const LetStmt *op) override {
105+ return visit_let_or_letstmt (op);
102106 }
103107
104108 Expr visit (const Call *op) override {
@@ -206,7 +210,7 @@ class LegalizeVectors : public IRMutator {
206210 // First mark this Let as sliceable before mutating the body:
207211 ScopedBinding<> vector_is_slicable (sliceable_vectors, op->name );
208212
209- Stmt body = mutate (op->body );
213+ auto body = mutate (op->body );
210214 // Here we know which requested vector variable slices should be created for the body of the Let/LetStmt to work.
211215
212216 if (std::vector<VectorSlice> *reqs = requested_slices.shallow_find (op->name )) {
@@ -228,8 +232,8 @@ class LegalizeVectors : public IRMutator {
228232 }
229233
230234 Expr visit (const Let *op) override {
231- // TODO is this still true?
232- internal_error << " Lets should have been lifted into LetStmts. " ;
235+ bool exceeds_lanecount = op-> value . type (). lanes () > max_lanes;
236+ internal_assert (!exceeds_lanecount) << " All illegal Let's should have been converted to LetStmts" ;
233237 return IRMutator::visit (op);
234238 }
235239
@@ -238,20 +242,61 @@ class LegalizeVectors : public IRMutator {
238242 if (exceeds_lanecount) {
239243 // Split up in multiple stores
240244 int num_vecs = (op->index .type ().lanes () + max_lanes - 1 ) / max_lanes;
245+
246+ std::vector<Expr> bundle_args;
247+ bundle_args.reserve (num_vecs * 3 );
248+
249+ // Break up the index, predicate, and value of the Store into legal chunks.
250+ for (int i = 0 ; i < num_vecs; ++i) {
251+ int lane_start = i * max_lanes;
252+ int lane_count_for_vec = std::min (op->value .type ().lanes () - lane_start, max_lanes);
253+
254+ // Pack them in a known order: rhs, index, predicate
255+ bundle_args.push_back (extract_lanes (op->value , lane_start, 1 , lane_count_for_vec, sliceable_vectors, requested_slices));
256+ bundle_args.push_back (extract_lanes (op->index , lane_start, 1 , lane_count_for_vec, sliceable_vectors, requested_slices));
257+ bundle_args.push_back (extract_lanes (op->predicate , lane_start, 1 , lane_count_for_vec, sliceable_vectors, requested_slices));
258+ }
259+
260+ // Run CSE on the joint bundle
261+ Expr joint_bundle = Call::make (Int (32 ), Call::bundle, bundle_args, Call::PureIntrinsic);
262+ joint_bundle = common_subexpression_elimination (joint_bundle);
263+
264+ // Peel off the `Let` expressions introduced by the CSE pass
265+ std::vector<std::pair<std::string, Expr>> let_bindings;
266+ while (const Let *let = joint_bundle.as <Let>()) {
267+ let_bindings.emplace_back (let->name , let->value );
268+ joint_bundle = let->body ;
269+ }
270+
271+ // Destructure the bundle to get our optimized expressions
272+ const Call *struct_call = joint_bundle.as <Call>();
273+ internal_assert (struct_call && struct_call->is_intrinsic (Call::bundle))
274+ << " Expected the CSE bundle to remain a bundle Call." ;
275+
276+ // Construct the legal stores with the CSE'd expressions
241277 std::vector<Stmt> assignments;
242278 assignments.reserve (num_vecs);
243279 for (int i = 0 ; i < num_vecs; ++i) {
244280 int lane_start = i * max_lanes;
245- int lane_count_for_vec = std::min (op->value .type ().lanes () - lane_start, max_lanes);
246281
247- Expr rhs = extract_lanes (op->value , lane_start, 1 , lane_count_for_vec, sliceable_vectors, requested_slices);
248- Expr index = extract_lanes (op->index , lane_start, 1 , lane_count_for_vec, sliceable_vectors, requested_slices);
249- Expr predictate = extract_lanes (op->predicate , lane_start, 1 , lane_count_for_vec, sliceable_vectors, requested_slices);
282+ // Unpack in the same order we packed them
283+ Expr rhs = struct_call->args [i * 3 + 0 ];
284+ Expr index = struct_call->args [i * 3 + 1 ];
285+ Expr predicate = struct_call->args [i * 3 + 2 ];
286+
250287 assignments.push_back (Store::make (
251288 op->name , std::move (rhs), std::move (index),
252- op->param , std::move (predictate ), op->alignment + lane_start));
289+ op->param , std::move (predicate ), op->alignment + lane_start));
253290 }
291+
254292 Stmt result = Block::make (assignments);
293+
294+ // Wrap the block in LetStmts to properly scope all shared expressions
295+ // Iterate backwards to build the LetStmt tree from the inside out.
296+ for (auto &let : reverse_view (let_bindings)) {
297+ result = LetStmt::make (let.first , let.second , result);
298+ }
299+
255300 debug (3 ) << " Legalized store " << Stmt (op) << " => " << result << " \n " ;
256301 return result;
257302 }
0 commit comments