Skip to content

Commit

Permalink
fix local load elimination pass
Browse files Browse the repository at this point in the history
  • Loading branch information
Mike-Leo-Smith committed Jan 12, 2025
1 parent 77fe8b7 commit f9117f6
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 44 deletions.
2 changes: 1 addition & 1 deletion src/backends/fallback/fallback_shader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ FallbackShader::FallbackShader(FallbackDevice *device, const ShaderOption &optio
auto dce2_info = xir::dce_pass_run_on_module(xir_module);
LUISA_INFO("Forwarded {} store instruction(s), "
"eliminated {} load instruction(s), "
"removed {} dead instructions in {} ms.",
"removed {} dead instruction(s) in {} ms.",
store_forward_info.forwarded_instructions.size(),
load_elim_info.eliminated_instructions.size(),
dce1_info.removed_instructions.size() + dce2_info.removed_instructions.size(),
Expand Down
1 change: 1 addition & 0 deletions src/xir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ set(LUISA_COMPUTE_XIR_SOURCES
translators/xir2text.cpp

# passes
passes/helpers.cpp
passes/dce.cpp
passes/dom_tree.cpp
passes/outline.cpp
Expand Down
27 changes: 27 additions & 0 deletions src/xir/passes/helpers.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#include <luisa/xir/instructions/alloca.h>
#include <luisa/xir/instructions/gep.h>
#include "helpers.h"

namespace luisa::compute::xir {

AllocaInst *trace_pointer_base_local_alloca_inst(Value *pointer) noexcept {
if (pointer == nullptr || pointer->derived_value_tag() != DerivedValueTag::INSTRUCTION) {
return nullptr;
}
switch (auto inst = static_cast<Instruction *>(pointer); inst->derived_instruction_tag()) {
case DerivedInstructionTag::ALLOCA: {
if (auto alloca_inst = static_cast<AllocaInst *>(inst); alloca_inst->space() == AllocSpace::LOCAL) {
return alloca_inst;
}
return nullptr;
}
case DerivedInstructionTag::GEP: {
auto gep_inst = static_cast<GEPInst *>(inst);
return trace_pointer_base_local_alloca_inst(gep_inst->base());
}
default: break;
}
return nullptr;
}

}// namespace luisa::compute::xir
12 changes: 12 additions & 0 deletions src/xir/passes/helpers.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#pragma once

#include <luisa/core/dll_export.h>

namespace luisa::compute::xir {

class AllocaInst;
class Value;

[[nodiscard]] LC_XIR_API AllocaInst *trace_pointer_base_local_alloca_inst(Value *pointer) noexcept;

}// namespace luisa::compute::xir
25 changes: 4 additions & 21 deletions src/xir/passes/local_load_elimination.cpp
Original file line number Diff line number Diff line change
@@ -1,32 +1,14 @@
#include <luisa/xir/function.h>
#include <luisa/xir/module.h>
#include <luisa/xir/builder.h>

#include "helpers.h"
#include <luisa/xir/passes/local_load_elimination.h>

namespace luisa::compute::xir {

namespace detail {

[[nodiscard]] static AllocaInst *trace_pointer_base_local_alloca_inst(Value *pointer) noexcept {
if (pointer == nullptr || pointer->derived_value_tag() != DerivedValueTag::INSTRUCTION) {
return nullptr;
}
switch (auto inst = static_cast<Instruction *>(pointer); inst->derived_instruction_tag()) {
case DerivedInstructionTag::ALLOCA: {
if (auto alloca_inst = static_cast<AllocaInst *>(inst); alloca_inst->space() == AllocSpace::LOCAL) {
return alloca_inst;
}
return nullptr;
}
case DerivedInstructionTag::GEP: {
auto gep_inst = static_cast<GEPInst *>(inst);
return trace_pointer_base_local_alloca_inst(gep_inst->base());
}
default: break;
}
return nullptr;
}

static void run_local_load_elimination_on_basic_block(luisa::unordered_set<BasicBlock *> &visited,
BasicBlock *block,
LocalLoadEliminationInfo &info) noexcept {
Expand Down Expand Up @@ -57,7 +39,8 @@ static void run_local_load_elimination_on_basic_block(luisa::unordered_set<Basic
auto load = static_cast<LoadInst *>(&inst);
if (auto iter = already_loaded.find(load->variable()); iter != already_loaded.end()) {
removable_loads.emplace(load, iter->second);
} else {
} else if (auto alloca_inst = trace_pointer_base_local_alloca_inst(load->variable())) {
variable_pointers[alloca_inst].emplace_back(load->variable());
already_loaded[load->variable()] = load;
}
break;
Expand Down
24 changes: 2 additions & 22 deletions src/xir/passes/local_store_forward.cpp
Original file line number Diff line number Diff line change
@@ -1,35 +1,15 @@
#include "luisa/core/logging.h"

#include <luisa/core/logging.h>
#include <luisa/xir/function.h>
#include <luisa/xir/module.h>
#include <luisa/xir/builder.h>

#include "helpers.h"
#include <luisa/xir/passes/local_store_forward.h>

namespace luisa::compute::xir {

namespace detail {

[[nodiscard]] static AllocaInst *trace_pointer_base_local_alloca_inst(Value *pointer) noexcept {
if (pointer == nullptr || pointer->derived_value_tag() != DerivedValueTag::INSTRUCTION) {
return nullptr;
}
switch (auto inst = static_cast<Instruction *>(pointer); inst->derived_instruction_tag()) {
case DerivedInstructionTag::ALLOCA: {
if (auto alloca_inst = static_cast<AllocaInst *>(inst); alloca_inst->space() == AllocSpace::LOCAL) {
return alloca_inst;
}
return nullptr;
}
case DerivedInstructionTag::GEP: {
auto gep_inst = static_cast<GEPInst *>(inst);
return trace_pointer_base_local_alloca_inst(gep_inst->base());
}
default: break;
}
return nullptr;
}

// TODO: we only handle local alloca's in straight-line code for now
static void run_local_store_forward_on_basic_block(luisa::unordered_set<BasicBlock *> &visited,
BasicBlock *block,
Expand Down

0 comments on commit f9117f6

Please sign in to comment.