Skip to content

Commit 61d5c55

Browse files
committed
Use CSE across stores during legalization.
1 parent 6565880 commit 61d5c55

File tree

1 file changed

+61
-16
lines changed

1 file changed

+61
-16
lines changed

src/LegalizeVectors.cpp

Lines changed: 61 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -84,21 +84,25 @@ class LiftExceedingVectors : public IRMutator {
8484
vector<pair<string, Expr>> lets;
8585
bool just_in_let_definition{false};
8686

87-
Expr visit(const Let *op) override {
88-
internal_error << "We don't want to process Lets. They should have all been converted to LetStmts.";
89-
return IRMutator::visit(op);
90-
}
91-
92-
Stmt visit(const LetStmt *op) override {
87+
template<typename LetOrLetStmt>
88+
auto visit_let_or_letstmt(const LetOrLetStmt *op) -> decltype(op->body) {
9389
just_in_let_definition = true;
9490
Expr def = mutate(op->value);
9591
just_in_let_definition = false;
9692

97-
Stmt body = mutate(op->body);
93+
decltype(op->body) body = mutate(op->body);
9894
if (def.same_as(op->value) && body.same_as(op->body)) {
9995
return op;
10096
}
101-
return LetStmt::make(op->name, std::move(def), std::move(body));
97+
return LetOrLetStmt::make(op->name, std::move(def), std::move(body));
98+
}
99+
100+
Expr visit(const Let *op) override {
101+
return visit_let_or_letstmt(op);
102+
}
103+
104+
Stmt visit(const LetStmt *op) override {
105+
return visit_let_or_letstmt(op);
102106
}
103107

104108
Expr visit(const Call *op) override {
@@ -206,7 +210,7 @@ class LegalizeVectors : public IRMutator {
206210
// First mark this Let as sliceable before mutating the body:
207211
ScopedBinding<> vector_is_slicable(sliceable_vectors, op->name);
208212

209-
Stmt body = mutate(op->body);
213+
auto body = mutate(op->body);
210214
// Here we know which requested vector variable slices should be created for the body of the Let/LetStmt to work.
211215

212216
if (std::vector<VectorSlice> *reqs = requested_slices.shallow_find(op->name)) {
@@ -228,8 +232,8 @@ class LegalizeVectors : public IRMutator {
228232
}
229233

230234
Expr visit(const Let *op) override {
231-
// TODO is this still true?
232-
internal_error << "Lets should have been lifted into LetStmts.";
235+
bool exceeds_lanecount = op->value.type().lanes() > max_lanes;
236+
internal_assert(!exceeds_lanecount) << "All illegal Let's should have been converted to LetStmts";
233237
return IRMutator::visit(op);
234238
}
235239

@@ -238,20 +242,61 @@ class LegalizeVectors : public IRMutator {
238242
if (exceeds_lanecount) {
239243
// Split up in multiple stores
240244
int num_vecs = (op->index.type().lanes() + max_lanes - 1) / max_lanes;
245+
246+
std::vector<Expr> bundle_args;
247+
bundle_args.reserve(num_vecs * 3);
248+
249+
// Break up the index, predicate, and value of the Store into legal chunks.
250+
for (int i = 0; i < num_vecs; ++i) {
251+
int lane_start = i * max_lanes;
252+
int lane_count_for_vec = std::min(op->value.type().lanes() - lane_start, max_lanes);
253+
254+
// Pack them in a known order: rhs, index, predicate
255+
bundle_args.push_back(extract_lanes(op->value, lane_start, 1, lane_count_for_vec, sliceable_vectors, requested_slices));
256+
bundle_args.push_back(extract_lanes(op->index, lane_start, 1, lane_count_for_vec, sliceable_vectors, requested_slices));
257+
bundle_args.push_back(extract_lanes(op->predicate, lane_start, 1, lane_count_for_vec, sliceable_vectors, requested_slices));
258+
}
259+
260+
// Run CSE on the joint bundle
261+
Expr joint_bundle = Call::make(Int(32), Call::bundle, bundle_args, Call::PureIntrinsic);
262+
joint_bundle = common_subexpression_elimination(joint_bundle);
263+
264+
// Peel off the `Let` expressions introduced by the CSE pass
265+
std::vector<std::pair<std::string, Expr>> let_bindings;
266+
while (const Let *let = joint_bundle.as<Let>()) {
267+
let_bindings.emplace_back(let->name, let->value);
268+
joint_bundle = let->body;
269+
}
270+
271+
// Destructure the bundle to get our optimized expressions
272+
const Call *struct_call = joint_bundle.as<Call>();
273+
internal_assert(struct_call && struct_call->is_intrinsic(Call::bundle))
274+
<< "Expected the CSE bundle to remain a bundle Call.";
275+
276+
// Construct the legal stores with the CSE'd expressions
241277
std::vector<Stmt> assignments;
242278
assignments.reserve(num_vecs);
243279
for (int i = 0; i < num_vecs; ++i) {
244280
int lane_start = i * max_lanes;
245-
int lane_count_for_vec = std::min(op->value.type().lanes() - lane_start, max_lanes);
246281

247-
Expr rhs = extract_lanes(op->value, lane_start, 1, lane_count_for_vec, sliceable_vectors, requested_slices);
248-
Expr index = extract_lanes(op->index, lane_start, 1, lane_count_for_vec, sliceable_vectors, requested_slices);
249-
Expr predictate = extract_lanes(op->predicate, lane_start, 1, lane_count_for_vec, sliceable_vectors, requested_slices);
282+
// Unpack in the same order we packed them
283+
Expr rhs = struct_call->args[i * 3 + 0];
284+
Expr index = struct_call->args[i * 3 + 1];
285+
Expr predicate = struct_call->args[i * 3 + 2];
286+
250287
assignments.push_back(Store::make(
251288
op->name, std::move(rhs), std::move(index),
252-
op->param, std::move(predictate), op->alignment + lane_start));
289+
op->param, std::move(predicate), op->alignment + lane_start));
253290
}
291+
254292
Stmt result = Block::make(assignments);
293+
294+
// Wrap the block in LetStmts to properly scope all shared expressions
295+
// Iterate backwards to build the LetStmt tree from the inside out.
296+
for (auto &let : reverse_view(let_bindings)) {
297+
result = LetStmt::make(let.first, let.second, result);
298+
}
299+
255300
debug(3) << "Legalized store " << Stmt(op) << " => " << result << "\n";
256301
return result;
257302
}

0 commit comments

Comments
 (0)