Skip to content

Commit

Permalink
fix mem2reg pass
Browse files Browse the repository at this point in the history
  • Loading branch information
Mike-Leo-Smith committed Feb 10, 2025
1 parent 4788c1b commit acf4da9
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 28 deletions.
22 changes: 8 additions & 14 deletions src/xir/passes/mem2reg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ struct PhiInsertionAndRenaming {
inserted.emplace_back(phi);
info.inserted_phi_instructions.emplace(phi);
// update the block-out value (note: we will overwrite it later if the block contains a store)
out_values.emplace(fb, phi);
out_values[fb] = phi;
// replace the load instructions in the same block with the new phi node
if (auto use_iter = analysis.use_blocks.find(fb); use_iter != analysis.use_blocks.end()) {
replace_load_with_value(use_iter->second, phi, info);
Expand All @@ -164,7 +164,7 @@ struct PhiInsertionAndRenaming {
}
// overwrite the block-out values with the store values
for (auto [def_block, store] : analysis.def_blocks) {
out_values.emplace(def_block, store->value());
out_values[def_block] = store->value();
}
// each of the use blocks must be dominated by some def/phi block, or it must contain undefined value
for (auto [use_block, load_inst] : analysis.use_blocks) {
Expand Down Expand Up @@ -245,20 +245,14 @@ static void simplify_single_block_store_load(AllocaInst *inst, AllocaStoreLoadSe
break;
}
case DerivedInstructionTag::STORE: {
auto store_inst = static_cast<StoreInst *>(store_or_load);
auto stored_value = store_inst->value();
LUISA_DEBUG_ASSERT(stored_value != nullptr, "Invalid store.");
if (last_store != nullptr) {// we have overwritten the last store so remove it
// we have overwritten the last store so remove it if any
if (last_store != nullptr) {
remove_store(last_store, info);
}
// detect for redundant store where a previously loaded value is stored back
if (stored_value == last_value) {
remove_store(store_inst, info);
last_store = nullptr;
} else {// update the last store and value
last_store = store_inst;
last_value = store_inst->value();
}
// record this store
last_store = static_cast<StoreInst *>(store_or_load);
last_value = last_store->value();
LUISA_DEBUG_ASSERT(last_value != nullptr, "Invalid store.");
break;
}
default: LUISA_ERROR_WITH_LOCATION("Invalid instruction.");
Expand Down
26 changes: 12 additions & 14 deletions src/xir/tests/test_mem2reg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,21 @@ int main(int argc, char *argv[]) {

using namespace luisa::compute;
auto shader = device.compile<1>([&](UInt n) noexcept {
auto b = compute::detail::FunctionBuilder::current();
auto t = Type::of<uint>();
auto x = b->local(t);
auto zero = b->literal(t, 0u);
auto one = b->literal(t, 1u);
$if (n == 0u) {
}
$elif (n == 1u) {
// b->assign(x, one);
}
$else {
b->assign(x, zero);
auto z = def(1u);
$loop {
UInt x;
x = 2u;
x = 2u;
n -= 1u;
$if (n == 1u) {
$break;
};
z *= x;
};
buffer->write(0u, def<uint>(x));
buffer->write(0u, z);
});

uint result = 0u;
auto result = 0u;
stream << shader(10u).dispatch(1u)
<< buffer.copy_to(&result)
<< synchronize();
Expand Down

0 comments on commit acf4da9

Please sign in to comment.