diff --git a/src/xir/passes/mem2reg.cpp b/src/xir/passes/mem2reg.cpp index b41b29b97..82d8da72d 100644 --- a/src/xir/passes/mem2reg.cpp +++ b/src/xir/passes/mem2reg.cpp @@ -151,7 +151,7 @@ struct PhiInsertionAndRenaming { inserted.emplace_back(phi); info.inserted_phi_instructions.emplace(phi); // update the block-out value (note: we will overwrite it later if the block contains a store) - out_values.emplace(fb, phi); + out_values[fb] = phi; // replace the load instructions in the same block with the new phi node if (auto use_iter = analysis.use_blocks.find(fb); use_iter != analysis.use_blocks.end()) { replace_load_with_value(use_iter->second, phi, info); @@ -164,7 +164,7 @@ struct PhiInsertionAndRenaming { } // overwrite the block-out values with the store values for (auto [def_block, store] : analysis.def_blocks) { - out_values.emplace(def_block, store->value()); + out_values[def_block] = store->value(); } // each of the use blocks must be dominated by some def/phi block, or it must contain undefined value for (auto [use_block, load_inst] : analysis.use_blocks) { @@ -245,20 +245,14 @@ static void simplify_single_block_store_load(AllocaInst *inst, AllocaStoreLoadSe break; } case DerivedInstructionTag::STORE: { - auto store_inst = static_cast(store_or_load); - auto stored_value = store_inst->value(); - LUISA_DEBUG_ASSERT(stored_value != nullptr, "Invalid store."); - if (last_store != nullptr) {// we have overwritten the last store so remove it + // we have overwritten the last store so remove it if any + if (last_store != nullptr) { remove_store(last_store, info); } - // detect for redundant store where a previously loaded value is stored back - if (stored_value == last_value) { - remove_store(store_inst, info); - last_store = nullptr; - } else {// update the last store and value - last_store = store_inst; - last_value = store_inst->value(); - } + // record this store + last_store = static_cast(store_or_load); + last_value = last_store->value(); + LUISA_DEBUG_ASSERT(last_value != nullptr, "Invalid store."); break; } default: LUISA_ERROR_WITH_LOCATION("Invalid instruction."); diff --git a/src/xir/tests/test_mem2reg.cpp b/src/xir/tests/test_mem2reg.cpp index 91857e435..0ed438faf 100644 --- a/src/xir/tests/test_mem2reg.cpp +++ b/src/xir/tests/test_mem2reg.cpp @@ -11,23 +11,21 @@ int main(int argc, char *argv[]) { using namespace luisa::compute; auto shader = device.compile<1>([&](UInt n) noexcept { - auto b = compute::detail::FunctionBuilder::current(); - auto t = Type::of(); - auto x = b->local(t); - auto zero = b->literal(t, 0u); - auto one = b->literal(t, 1u); - $if (n == 0u) { - } - $elif (n == 1u) { - // b->assign(x, one); - } - $else { - b->assign(x, zero); + auto z = def(1u); + $loop { + UInt x; + x = 2u; + x = 2u; + n -= 1u; + $if (n == 1u) { + $break; + }; + z *= x; }; - buffer->write(0u, def(x)); + buffer->write(0u, z); }); - uint result = 0u; + auto result = 0u; stream << shader(10u).dispatch(1u) << buffer.copy_to(&result) << synchronize();