From 38c9bf41e7145655d3d852aa1c7b46c77ab787e4 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Fri, 10 Jan 2025 19:34:36 +0800 Subject: [PATCH] add straightline load-to-load forwarding --- .../luisa/xir/passes/local_load_elimination.h | 5 + .../luisa/xir/passes/local_store_forward.h | 3 - src/backends/fallback/fallback_device.cpp | 4 - src/backends/fallback/fallback_shader.cpp | 4 + src/xir/passes/local_load_elimination.cpp | 121 ++++++++++++++++++ src/xir/passes/local_store_forward.cpp | 17 ++- 6 files changed, 141 insertions(+), 13 deletions(-) diff --git a/include/luisa/xir/passes/local_load_elimination.h b/include/luisa/xir/passes/local_load_elimination.h index 92a2887e8..f355a11b9 100644 --- a/include/luisa/xir/passes/local_load_elimination.h +++ b/include/luisa/xir/passes/local_load_elimination.h @@ -9,6 +9,11 @@ class LoadInst; class Function; class Module; +// This pass implements a simple local load elimination optimization. +// For each load instruction, if a recent load instruction with the same +// variable has happened without possible intervening stores, the load +// instruction can be replaced with the recent load instruction. + struct LocalLoadEliminationInfo { luisa::unordered_map eliminated_instructions; }; diff --git a/include/luisa/xir/passes/local_store_forward.h b/include/luisa/xir/passes/local_store_forward.h index 22f725e9e..aa26e1b53 100644 --- a/include/luisa/xir/passes/local_store_forward.h +++ b/include/luisa/xir/passes/local_store_forward.h @@ -15,9 +15,6 @@ class Module; // This pass is used to forward stores to loads for scalar variables // within straight-line basic blocks. It is a simple peephole optimization // that can be used to reduce the number of memory operations. -// Note: this pass does not remove the original store instructions. -// It only forwards the values to the loads. To remove the original -// store instructions, a DCE pass should be used after this pass. struct LocalStoreForwardInfo { luisa::unordered_map forwarded_instructions; diff --git a/src/backends/fallback/fallback_device.cpp b/src/backends/fallback/fallback_device.cpp index 38cd41fb7..e3bfead93 100644 --- a/src/backends/fallback/fallback_device.cpp +++ b/src/backends/fallback/fallback_device.cpp @@ -42,10 +42,6 @@ FallbackDevice::FallbackDevice(Context &&ctx) noexcept // embree _rtc_device = rtcNewDevice("frequency_level=simd128,isa=avx2,verbose=1"); - auto embree_version_major = rtcGetDeviceProperty(_rtc_device, RTC_DEVICE_PROPERTY_VERSION_MAJOR); - auto embree_version_minor = rtcGetDeviceProperty(_rtc_device, RTC_DEVICE_PROPERTY_VERSION_MINOR); - auto embree_version_patch = rtcGetDeviceProperty(_rtc_device, RTC_DEVICE_PROPERTY_VERSION_PATCH); - LUISA_INFO("Embree version: {}.{}.{}", embree_version_major, embree_version_minor, embree_version_patch); rtcSetDeviceErrorFunction( _rtc_device, [](void *, RTCError code, const char *message) { diff --git a/src/backends/fallback/fallback_shader.cpp b/src/backends/fallback/fallback_shader.cpp index a3d2a7021..31025bbf0 100644 --- a/src/backends/fallback/fallback_shader.cpp +++ b/src/backends/fallback/fallback_shader.cpp @@ -26,6 +26,7 @@ #include #include +#include #include "../common/shader_print_formatter.h" @@ -175,10 +176,13 @@ FallbackShader::FallbackShader(FallbackDevice *device, const ShaderOption &optio Clock opt_clk; auto dce1_info = xir::dce_pass_run_on_module(xir_module); auto store_forward_info = xir::local_store_forward_pass_run_on_module(xir_module); + auto load_elim_info = xir::local_load_elimination_pass_run_on_module(xir_module); auto dce2_info = xir::dce_pass_run_on_module(xir_module); LUISA_INFO("Forwarded {} store instruction(s), " + "eliminated {} load instruction(s), " "removed {} dead instructions in {} ms.", store_forward_info.forwarded_instructions.size(), + load_elim_info.eliminated_instructions.size(), dce1_info.removed_instructions.size() + dce2_info.removed_instructions.size(), opt_clk.toc()); diff --git a/src/xir/passes/local_load_elimination.cpp b/src/xir/passes/local_load_elimination.cpp index 5b2cbb993..3f773b006 100644 --- a/src/xir/passes/local_load_elimination.cpp +++ b/src/xir/passes/local_load_elimination.cpp @@ -1,5 +1,126 @@ +#include +#include +#include #include namespace luisa::compute::xir { +namespace detail { + +[[nodiscard]] static AllocaInst *trace_pointer_base_local_alloca_inst(Value *pointer) noexcept { + if (pointer == nullptr || pointer->derived_value_tag() != DerivedValueTag::INSTRUCTION) { + return nullptr; + } + switch (auto inst = static_cast(pointer); inst->derived_instruction_tag()) { + case DerivedInstructionTag::ALLOCA: { + if (auto alloca_inst = static_cast(inst); alloca_inst->space() == AllocSpace::LOCAL) { + return alloca_inst; + } + return nullptr; + } + case DerivedInstructionTag::GEP: { + auto gep_inst = static_cast(inst); + return trace_pointer_base_local_alloca_inst(gep_inst->base()); + } + default: break; + } + return nullptr; +} + +static void run_local_load_elimination_on_basic_block(luisa::unordered_set &visited, + BasicBlock *block, + LocalLoadEliminationInfo &info) noexcept { + + luisa::unordered_map> variable_pointers;// maps variables to pointers + luisa::unordered_map already_loaded; // maps pointers to the earliest load instructions + luisa::unordered_map removable_loads; // maps loads to the load that can be forwarded + + auto invalidate_interfering_loads = [&](Value *ptr) noexcept -> AllocaInst * { + if (auto alloca_inst = trace_pointer_base_local_alloca_inst(ptr)) { + auto &interfering_ptrs = variable_pointers[alloca_inst]; + interfering_ptrs.emplace_back(ptr); + for (auto interfering_ptr : interfering_ptrs) { + already_loaded.erase(interfering_ptr); + } + return alloca_inst; + } + return nullptr; + }; + + // we visit the block and all of its single straight-line successors to find the earliest loads + while (visited.emplace(block).second) { + + // process the instructions in the block + for (auto &&inst : block->instructions()) { + switch (inst.derived_instruction_tag()) { + case DerivedInstructionTag::LOAD: { + auto load = static_cast(&inst); + if (auto iter = already_loaded.find(load->variable()); iter != already_loaded.end()) { + removable_loads.emplace(load, iter->second); + } else { + already_loaded[load->variable()] = load; + } + break; + } + case DerivedInstructionTag::GEP: { + // users of GEPs will handle the forwarding, so we don't need to do anything here + break; + } + default: { + for (auto op_use : inst.operand_uses()) { + invalidate_interfering_loads(op_use->value()); + } + break; + } + } + } + + // move to the next block if it is the only successor and only has a single predecessor + BasicBlock *next = nullptr; + auto successor_count = 0u; + block->traverse_successors(true, [&](BasicBlock *succ) noexcept { + successor_count++; + next = succ; + }); + if (successor_count != 1) { break; } + // check if the next block has a single predecessor + auto pred_count = 0u; + next->traverse_predecessors(false, [&](BasicBlock *) noexcept { pred_count++; }); + if (pred_count != 1) { break; } + block = next; + } + + // process the instructions + for (auto [current_load, earlier_load] : removable_loads) { + current_load->replace_all_uses_with(earlier_load); + current_load->remove_self(); + info.eliminated_instructions.emplace(current_load, earlier_load); + } } + +static void run_local_load_elimination_on_function(Function *function, LocalLoadEliminationInfo &info) noexcept { + if (auto definition = function->definition()) { + luisa::unordered_set visited; + definition->traverse_basic_blocks(BasicBlockTraversalOrder::REVERSE_POST_ORDER, [&](BasicBlock *block) noexcept { + run_local_load_elimination_on_basic_block(visited, block, info); + }); + } +} + +}// namespace detail + +LocalLoadEliminationInfo local_load_elimination_pass_run_on_function(Function *function) noexcept { + LocalLoadEliminationInfo info; + detail::run_local_load_elimination_on_function(function, info); + return info; +} + +LocalLoadEliminationInfo local_load_elimination_pass_run_on_module(Module *module) noexcept { + LocalLoadEliminationInfo info; + for (auto &&f : module->functions()) { + detail::run_local_load_elimination_on_function(&f, info); + } + return info; +} + +}// namespace luisa::compute::xir diff --git a/src/xir/passes/local_store_forward.cpp b/src/xir/passes/local_store_forward.cpp index a65af767d..eaf275f3b 100644 --- a/src/xir/passes/local_store_forward.cpp +++ b/src/xir/passes/local_store_forward.cpp @@ -8,7 +8,7 @@ namespace luisa::compute::xir { namespace detail { -[[nodiscard]] AllocaInst *trace_pointer_base_local_alloca_inst(Value *pointer) noexcept { +[[nodiscard]] static AllocaInst *trace_pointer_base_local_alloca_inst(Value *pointer) noexcept { if (pointer == nullptr || pointer->derived_value_tag() != DerivedValueTag::INSTRUCTION) { return nullptr; } @@ -30,7 +30,8 @@ namespace detail { // TODO: we only handle local alloca's in straight-line code for now static void run_local_store_forward_on_basic_block(luisa::unordered_set &visited, - BasicBlock *block, LocalStoreForwardInfo &info) noexcept { + BasicBlock *block, + LocalStoreForwardInfo &info) noexcept { luisa::unordered_map> variable_pointers;// maps variables to pointers luisa::unordered_map latest_stores; // maps pointers to the latest store instruction @@ -38,8 +39,9 @@ static void run_local_store_forward_on_basic_block(luisa::unordered_set AllocaInst * { if (auto alloca_inst = trace_pointer_base_local_alloca_inst(ptr)) { - variable_pointers[alloca_inst].emplace_back(ptr); - for (auto interfering_ptr : variable_pointers[alloca_inst]) { + auto &interfering_ptrs = variable_pointers[alloca_inst]; + interfering_ptrs.emplace_back(ptr); + for (auto interfering_ptr : interfering_ptrs) { latest_stores.erase(interfering_ptr); } return alloca_inst; @@ -49,6 +51,7 @@ static void run_local_store_forward_on_basic_block(luisa::unordered_setinstructions()) { switch (inst.derived_instruction_tag()) { @@ -93,14 +96,16 @@ static void run_local_store_forward_on_basic_block(luisa::unordered_setreplace_all_uses_with(store->value()); load->remove_self(); info.forwarded_instructions.emplace(load, store); } } -void run_local_store_forward_on_function(Function *function, LocalStoreForwardInfo &info) noexcept { +static void run_local_store_forward_on_function(Function *function, LocalStoreForwardInfo &info) noexcept { if (auto definition = function->definition()) { luisa::unordered_set visited; definition->traverse_basic_blocks(BasicBlockTraversalOrder::REVERSE_POST_ORDER, [&](BasicBlock *block) noexcept {