Skip to content

Commit

Permalink
add straightline load-to-load forwarding
Browse files Browse the repository at this point in the history
  • Loading branch information
Mike-Leo-Smith committed Jan 10, 2025
1 parent aa4e9d9 commit 38c9bf4
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 13 deletions.
5 changes: 5 additions & 0 deletions include/luisa/xir/passes/local_load_elimination.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ class LoadInst;
class Function;
class Module;

// This pass implements a simple local load elimination optimization.
// For each load instruction, if an earlier load from the same variable
// exists and no store could have modified the variable in between, the
// later load is replaced with the result of the earlier one.

// Result of running the local load elimination pass.
struct LocalLoadEliminationInfo {
// maps each eliminated load to the earlier load whose value now replaces it
luisa::unordered_map<LoadInst *, LoadInst *> eliminated_instructions;
};
Expand Down
3 changes: 0 additions & 3 deletions include/luisa/xir/passes/local_store_forward.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@ class Module;
// This pass is used to forward stores to loads for scalar variables
// within straight-line basic blocks. It is a simple peephole optimization
// that can be used to reduce the number of memory operations.
// Note: this pass does not remove the original store instructions.
// It only forwards the values to the loads. To remove the original
// store instructions, a DCE pass should be used after this pass.

struct LocalStoreForwardInfo {
luisa::unordered_map<LoadInst *, StoreInst *> forwarded_instructions;
Expand Down
4 changes: 0 additions & 4 deletions src/backends/fallback/fallback_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,6 @@ FallbackDevice::FallbackDevice(Context &&ctx) noexcept

// embree
_rtc_device = rtcNewDevice("frequency_level=simd128,isa=avx2,verbose=1");
auto embree_version_major = rtcGetDeviceProperty(_rtc_device, RTC_DEVICE_PROPERTY_VERSION_MAJOR);
auto embree_version_minor = rtcGetDeviceProperty(_rtc_device, RTC_DEVICE_PROPERTY_VERSION_MINOR);
auto embree_version_patch = rtcGetDeviceProperty(_rtc_device, RTC_DEVICE_PROPERTY_VERSION_PATCH);
LUISA_INFO("Embree version: {}.{}.{}", embree_version_major, embree_version_minor, embree_version_patch);
rtcSetDeviceErrorFunction(
_rtc_device,
[](void *, RTCError code, const char *message) {
Expand Down
4 changes: 4 additions & 0 deletions src/backends/fallback/fallback_shader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

#include <luisa/xir/passes/dce.h>
#include <luisa/xir/passes/local_store_forward.h>
#include <luisa/xir/passes/local_load_elimination.h>

#include "../common/shader_print_formatter.h"

Expand Down Expand Up @@ -175,10 +176,13 @@ FallbackShader::FallbackShader(FallbackDevice *device, const ShaderOption &optio
Clock opt_clk;
auto dce1_info = xir::dce_pass_run_on_module(xir_module);
auto store_forward_info = xir::local_store_forward_pass_run_on_module(xir_module);
auto load_elim_info = xir::local_load_elimination_pass_run_on_module(xir_module);
auto dce2_info = xir::dce_pass_run_on_module(xir_module);
LUISA_INFO("Forwarded {} store instruction(s), "
"eliminated {} load instruction(s), "
"removed {} dead instructions in {} ms.",
store_forward_info.forwarded_instructions.size(),
load_elim_info.eliminated_instructions.size(),
dce1_info.removed_instructions.size() + dce2_info.removed_instructions.size(),
opt_clk.toc());

Expand Down
121 changes: 121 additions & 0 deletions src/xir/passes/local_load_elimination.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,126 @@
#include <luisa/xir/function.h>
#include <luisa/xir/module.h>
#include <luisa/xir/builder.h>
#include <luisa/xir/passes/local_load_elimination.h>

namespace luisa::compute::xir {

namespace detail {

// Follows a pointer value up through any chain of GEP instructions and returns
// the alloca it is ultimately based on, provided that alloca lives in the
// LOCAL allocation space. Returns nullptr for null pointers, non-instruction
// values, non-local allocas, and pointers produced by any other instruction.
[[nodiscard]] static AllocaInst *trace_pointer_base_local_alloca_inst(Value *pointer) noexcept {
    Value *current = pointer;
    for (;;) {
        // only instructions can be (or lead to) an alloca
        if (current == nullptr) { return nullptr; }
        if (current->derived_value_tag() != DerivedValueTag::INSTRUCTION) { return nullptr; }
        auto inst = static_cast<Instruction *>(current);
        auto tag = inst->derived_instruction_tag();
        if (tag == DerivedInstructionTag::GEP) {
            // step to the pointer the GEP is computed from and keep walking
            current = static_cast<GEPInst *>(inst)->base();
            continue;
        }
        if (tag == DerivedInstructionTag::ALLOCA) {
            auto alloca_inst = static_cast<AllocaInst *>(inst);
            return alloca_inst->space() == AllocSpace::LOCAL ? alloca_inst : nullptr;
        }
        // any other producer is not something we can trace through
        return nullptr;
    }
}

// Performs load-to-load forwarding starting at `block` and continuing through
// its chain of straight-line successors (a block is followed only if it is the
// sole successor and itself has exactly one predecessor).
//
// Only loads whose pointer traces back to a LOCAL alloca are considered: those
// are the only pointers for which this pass can detect interfering uses. Any
// instruction other than a load or a GEP that takes such a pointer as an
// operand is conservatively treated as a possible write to the whole variable,
// which invalidates every load recorded for that variable.
//
// Parameters:
//   visited - blocks already processed by this pass (shared across calls);
//             a block is processed at most once.
//   block   - the block to start from.
//   info    - receives an entry (eliminated load -> replacing load) for each
//             load that gets removed.
static void run_local_load_elimination_on_basic_block(luisa::unordered_set<BasicBlock *> &visited,
                                                      BasicBlock *block,
                                                      LocalLoadEliminationInfo &info) noexcept {

    // maps each local variable to the pointers into it that currently have a recorded load
    luisa::unordered_map<AllocaInst *, luisa::vector<Value *>> variable_pointers;
    // maps pointers to the earliest load that is still known to be valid
    luisa::unordered_map<Value *, LoadInst *> already_loaded;
    // maps redundant loads to the earlier load that can replace them
    luisa::unordered_map<LoadInst *, LoadInst *> removable_loads;

    // Forget every recorded load of the variable that `ptr` points into.
    // Pointers that are not based on a local alloca are ignored, which is safe
    // because loads through such pointers are never recorded in the first place.
    auto invalidate_interfering_loads = [&](Value *ptr) noexcept {
        if (auto alloca_inst = trace_pointer_base_local_alloca_inst(ptr)) {
            if (auto iter = variable_pointers.find(alloca_inst); iter != variable_pointers.end()) {
                for (auto interfering_ptr : iter->second) {
                    already_loaded.erase(interfering_ptr);
                }
                // nothing is recorded for this variable any more
                iter->second.clear();
            }
        }
    };

    // we visit the block and all of its single straight-line successors to find the earliest loads
    while (visited.emplace(block).second) {

        // process the instructions in the block
        for (auto &&inst : block->instructions()) {
            switch (inst.derived_instruction_tag()) {
                case DerivedInstructionTag::LOAD: {
                    auto load = static_cast<LoadInst *>(&inst);
                    auto ptr = load->variable();
                    // skip loads we cannot reason about: without a local base
                    // alloca, intervening writes would go unnoticed
                    auto alloca_inst = trace_pointer_base_local_alloca_inst(ptr);
                    if (alloca_inst == nullptr) { break; }
                    if (auto iter = already_loaded.find(ptr); iter != already_loaded.end()) {
                        removable_loads.emplace(load, iter->second);
                    } else {
                        already_loaded[ptr] = load;
                        // register the pointer under its variable so that a write
                        // through *any* pointer into the variable invalidates it
                        variable_pointers[alloca_inst].emplace_back(ptr);
                    }
                    break;
                }
                case DerivedInstructionTag::GEP: {
                    // users of GEPs will handle the forwarding, so we don't need to do anything here
                    break;
                }
                default: {
                    // any other user of a pointer may write through it
                    for (auto op_use : inst.operand_uses()) {
                        invalidate_interfering_loads(op_use->value());
                    }
                    break;
                }
            }
        }

        // move to the next block if it is the only successor and only has a single predecessor
        // NOTE(review): the meaning of the boolean flags passed to the traversal
        // helpers is not visible here - confirm against BasicBlock's declaration
        BasicBlock *next = nullptr;
        auto successor_count = 0u;
        block->traverse_successors(true, [&](BasicBlock *succ) noexcept {
            successor_count++;
            next = succ;
        });
        if (successor_count != 1) { break; }
        // check if the next block has a single predecessor
        auto pred_count = 0u;
        next->traverse_predecessors(false, [&](BasicBlock *) noexcept { pred_count++; });
        if (pred_count != 1) { break; }
        block = next;
    }

    // rewrite the IR only after the scan, so the instruction lists are not
    // modified while they are being iterated
    for (auto [current_load, earlier_load] : removable_loads) {
        current_load->replace_all_uses_with(earlier_load);
        current_load->remove_self();
        info.eliminated_instructions.emplace(current_load, earlier_load);
    }
}

// Applies local load elimination to every basic block of `function`, visiting
// blocks in reverse post-order. Functions without a definition are skipped.
// Eliminated loads are accumulated into `info`.
static void run_local_load_elimination_on_function(Function *function, LocalLoadEliminationInfo &info) noexcept {
    auto definition = function->definition();
    if (!definition) { return; }
    // shared across blocks so that each block is processed at most once
    luisa::unordered_set<BasicBlock *> processed_blocks;
    definition->traverse_basic_blocks(
        BasicBlockTraversalOrder::REVERSE_POST_ORDER,
        [&](BasicBlock *bb) noexcept {
            run_local_load_elimination_on_basic_block(processed_blocks, bb, info);
        });
}

}// namespace detail

// Runs the local load elimination pass on a single function and returns the
// record of loads that were eliminated.
LocalLoadEliminationInfo local_load_elimination_pass_run_on_function(Function *function) noexcept {
    LocalLoadEliminationInfo result{};
    detail::run_local_load_elimination_on_function(function, result);
    return result;
}

// Runs the local load elimination pass on every function in `module` and
// returns the combined record of loads that were eliminated.
LocalLoadEliminationInfo local_load_elimination_pass_run_on_module(Module *module) noexcept {
    LocalLoadEliminationInfo result{};
    for (auto &&func : module->functions()) {
        auto *func_ptr = &func;
        detail::run_local_load_elimination_on_function(func_ptr, result);
    }
    return result;
}

}// namespace luisa::compute::xir
17 changes: 11 additions & 6 deletions src/xir/passes/local_store_forward.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ namespace luisa::compute::xir {

namespace detail {

[[nodiscard]] AllocaInst *trace_pointer_base_local_alloca_inst(Value *pointer) noexcept {
[[nodiscard]] static AllocaInst *trace_pointer_base_local_alloca_inst(Value *pointer) noexcept {
if (pointer == nullptr || pointer->derived_value_tag() != DerivedValueTag::INSTRUCTION) {
return nullptr;
}
Expand All @@ -30,16 +30,18 @@ namespace detail {

// TODO: we only handle local alloca's in straight-line code for now
static void run_local_store_forward_on_basic_block(luisa::unordered_set<BasicBlock *> &visited,
BasicBlock *block, LocalStoreForwardInfo &info) noexcept {
BasicBlock *block,
LocalStoreForwardInfo &info) noexcept {

luisa::unordered_map<AllocaInst *, luisa::vector<Value *>> variable_pointers;// maps variables to pointers
luisa::unordered_map<Value *, StoreInst *> latest_stores; // maps pointers to the latest store instruction
luisa::unordered_map<LoadInst *, StoreInst *> removable_loads; // maps loads to the store that can be forwarded

auto invalidate_interfering_stores = [&](Value *ptr) noexcept -> AllocaInst * {
if (auto alloca_inst = trace_pointer_base_local_alloca_inst(ptr)) {
variable_pointers[alloca_inst].emplace_back(ptr);
for (auto interfering_ptr : variable_pointers[alloca_inst]) {
auto &interfering_ptrs = variable_pointers[alloca_inst];
interfering_ptrs.emplace_back(ptr);
for (auto interfering_ptr : interfering_ptrs) {
latest_stores.erase(interfering_ptr);
}
return alloca_inst;
Expand All @@ -49,6 +51,7 @@ static void run_local_store_forward_on_basic_block(luisa::unordered_set<BasicBlo

// we visit the block and all of its single straight-line successors
while (visited.emplace(block).second) {

// process the instructions in the block
for (auto &&inst : block->instructions()) {
switch (inst.derived_instruction_tag()) {
Expand Down Expand Up @@ -93,14 +96,16 @@ static void run_local_store_forward_on_basic_block(luisa::unordered_set<BasicBlo
if (pred_count != 1) { break; }
block = next;
}
for (auto &&[load, store] : removable_loads) {

// perform the forwarding
for (auto [load, store] : removable_loads) {
load->replace_all_uses_with(store->value());
load->remove_self();
info.forwarded_instructions.emplace(load, store);
}
}

void run_local_store_forward_on_function(Function *function, LocalStoreForwardInfo &info) noexcept {
static void run_local_store_forward_on_function(Function *function, LocalStoreForwardInfo &info) noexcept {
if (auto definition = function->definition()) {
luisa::unordered_set<BasicBlock *> visited;
definition->traverse_basic_blocks(BasicBlockTraversalOrder::REVERSE_POST_ORDER, [&](BasicBlock *block) noexcept {
Expand Down

0 comments on commit 38c9bf4

Please sign in to comment.