Skip to content

Commit

Permalink
refactor new event as template
Browse files Browse the repository at this point in the history
  • Loading branch information
theo25 committed Dec 11, 2024
1 parent 4431b20 commit 1d0356a
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 93 deletions.
107 changes: 95 additions & 12 deletions include/kllvm/codegen/ProofEvent.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
#include "kllvm/ast/AST.h"
#include "kllvm/codegen/Decision.h"
#include "kllvm/codegen/DecisionParser.h"
#include "kllvm/codegen/Options.h"
#include "kllvm/codegen/Util.h"

#include "llvm/IR/Instructions.h"

#include <fmt/format.h>

#include <map>
#include <tuple>

Expand All @@ -21,31 +24,58 @@ class proof_event {

/*
* Load the boolean flag that controls whether proof hint output is enabled or
* not, then create a branch at the end of this basic block depending on the
* result.
* not, then create a branch at the specified location depending on the
* result. The location can be before a given instruction or at the end of a
* given basic block.
*
* Returns a pair of blocks [proof enabled, merge]; the first of these is
* intended for self-contained behaviour only relevant in proof output mode,
* while the second is for the continuation of the interpreter's previous
* behaviour.
*/
template <typename Location>
std::pair<llvm::BasicBlock *, llvm::BasicBlock *>
proof_branch(std::string const &label, llvm::BasicBlock *insert_at_end);
std::pair<llvm::BasicBlock *, llvm::BasicBlock *>
proof_branch(std::string const &label, llvm::Instruction *insert_before);
proof_branch(std::string const &label, Location *insert_loc);

/*
* Return the parent function of the given location.
* Template specializations for llvm::Instruction and llvm::BasicBlock.
*/
template <typename Location>
llvm::Function *get_parent_function(Location *loc);

/*
* Return the parent basic block of the given location.
* Template specializations for llvm::Instruction and llvm::BasicBlock.
*/
template <typename Location>
llvm::BasicBlock *get_parent_block(Location *loc);

/*
* If the given location is an Instruction, this method moves the instruction
* to the merge block.
* If the given location is a BasicBlock, this method simply emits a no-op
* instruction to the merge block.
* Template specializations for llvm::Instruction and llvm::BasicBlock.
*/
template <typename Location>
void fix_insert_loc(Location *loc, llvm::BasicBlock *merge_block);

/*
* Set up a standard event prelude by creating a pair of basic blocks for the
* proof output and continuation, then loading the output filename from its
* global.
* global. The location for the prelude can be before a given instruction or
* at the end of a given basic block.
*
* Returns a triple [proof enabled, merge, proof_writer]; see `proofBranch`
* and `emitGetOutputFileName`.
*/
template <typename Location>
std::tuple<llvm::BasicBlock *, llvm::BasicBlock *, llvm::Value *>
event_prelude(std::string const &label, llvm::BasicBlock *insert_at_end);
std::tuple<llvm::BasicBlock *, llvm::BasicBlock *, llvm::Value *>
event_prelude(std::string const &label, llvm::Instruction *insert_before);
event_prelude(std::string const &label, Location *insert_loc);

/*
* Set up a check of whether a new proof hint chunk should be started. The
Expand Down Expand Up @@ -239,9 +269,9 @@ class proof_event {
[[nodiscard]] llvm::BasicBlock *pattern_matching_failure(
kore_composite_pattern const &pattern, llvm::BasicBlock *current_block);

[[nodiscard]] llvm::BasicBlock *function_exit(
uint64_t ordinal, bool is_tail, llvm::Instruction *insert_before,
llvm::BasicBlock *current_block);
template <typename Location>
[[nodiscard]] llvm::BasicBlock *
function_exit(uint64_t ordinal, bool is_tail, Location *insert_loc);

proof_event(kore_definition *definition, llvm::Module *module)
: definition_(definition)
Expand All @@ -251,4 +281,57 @@ class proof_event {

} // namespace kllvm

//===----------------------------------------------------------------------===//
// Implementation for method templates
//===----------------------------------------------------------------------===//

template <typename Location>
std::pair<llvm::BasicBlock *, llvm::BasicBlock *>
kllvm::proof_event::proof_branch(
std::string const &label, Location *insert_loc) {
auto *i1_ty = llvm::Type::getInt1Ty(ctx_);

auto *proof_output_flag = module_->getOrInsertGlobal("proof_output", i1_ty);
auto *proof_output = new llvm::LoadInst(
i1_ty, proof_output_flag, "proof_output", insert_loc);

auto *f = get_parent_function(insert_loc);
auto *true_block
= llvm::BasicBlock::Create(ctx_, fmt::format("if_{}", label), f);
auto *merge_block
= llvm::BasicBlock::Create(ctx_, fmt::format("tail_{}", label), f);

llvm::BranchInst::Create(true_block, merge_block, proof_output, insert_loc);

fix_insert_loc(insert_loc, merge_block);

return {true_block, merge_block};
}

template <typename Location>
std::tuple<llvm::BasicBlock *, llvm::BasicBlock *, llvm::Value *>
kllvm::proof_event::event_prelude(
std::string const &label, Location *insert_loc) {
auto [true_block, merge_block] = proof_branch(label, insert_loc);
return {true_block, merge_block, emit_get_proof_trace_writer(true_block)};
}

template <typename Location>
llvm::BasicBlock *kllvm::proof_event::function_exit(
uint64_t ordinal, bool is_tail, Location *insert_loc) {

if (!proof_hint_instrumentation) {
return get_parent_block(insert_loc);
}

auto [true_block, merge_block, proof_writer]
= event_prelude("function_exit", insert_loc);

emit_write_function_exit(proof_writer, ordinal, is_tail, true_block);

llvm::BranchInst::Create(merge_block, true_block);

return merge_block;
}

#endif // PROOF_EVENT_H
13 changes: 6 additions & 7 deletions lib/codegen/CreateTerm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1295,20 +1295,19 @@ bool make_function(
if (is_apply_rule) {
current_block
= proof_event(definition, module)
.function_exit(ordinal, true, call, current_block);
.function_exit(
ordinal, true, llvm::dyn_cast<llvm::Instruction>(call));
}
} else {
if (is_apply_rule) {
current_block
= proof_event(definition, module)
.function_exit(ordinal, false, nullptr, current_block);
current_block = proof_event(definition, module)
.function_exit(ordinal, false, current_block);
}
}
} else {
if (is_apply_rule) {
current_block
= proof_event(definition, module)
.function_exit(ordinal, false, nullptr, current_block);
current_block = proof_event(definition, module)
.function_exit(ordinal, false, current_block);
}
}
}
Expand Down
106 changes: 32 additions & 74 deletions lib/codegen/ProofEvent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -372,63 +372,6 @@ proof_event::emit_get_proof_chunk_size(llvm::BasicBlock *insert_at_end) {
i64_ty, proof_chunk_size_pointer, "proof_chunk_size", insert_at_end);
}

std::pair<llvm::BasicBlock *, llvm::BasicBlock *> proof_event::proof_branch(
std::string const &label, llvm::BasicBlock *insert_at_end) {
auto *i1_ty = llvm::Type::getInt1Ty(ctx_);

auto *proof_output_flag = module_->getOrInsertGlobal("proof_output", i1_ty);
auto *proof_output = new llvm::LoadInst(
i1_ty, proof_output_flag, "proof_output", insert_at_end);

auto *f = insert_at_end->getParent();
auto *true_block
= llvm::BasicBlock::Create(ctx_, fmt::format("if_{}", label), f);
auto *merge_block
= llvm::BasicBlock::Create(ctx_, fmt::format("tail_{}", label), f);

emit_no_op(merge_block);

llvm::BranchInst::Create(
true_block, merge_block, proof_output, insert_at_end);
return {true_block, merge_block};
}

std::pair<llvm::BasicBlock *, llvm::BasicBlock *> proof_event::proof_branch(
std::string const &label, llvm::Instruction *insert_before) {
auto *i1_ty = llvm::Type::getInt1Ty(ctx_);

auto *proof_output_flag = module_->getOrInsertGlobal("proof_output", i1_ty);
auto *proof_output = new llvm::LoadInst(
i1_ty, proof_output_flag, "proof_output", insert_before);

auto *f = insert_before->getParent()->getParent();
auto *true_block
= llvm::BasicBlock::Create(ctx_, fmt::format("if_{}", label), f);
auto *merge_block
= llvm::BasicBlock::Create(ctx_, fmt::format("tail_{}", label), f);

llvm::BranchInst::Create(
true_block, merge_block, proof_output, insert_before);

insert_before->moveBefore(*merge_block, merge_block->begin());

return {true_block, merge_block};
}

std::tuple<llvm::BasicBlock *, llvm::BasicBlock *, llvm::Value *>
proof_event::event_prelude(
std::string const &label, llvm::BasicBlock *insert_at_end) {
auto [true_block, merge_block] = proof_branch(label, insert_at_end);
return {true_block, merge_block, emit_get_proof_trace_writer(true_block)};
}

std::tuple<llvm::BasicBlock *, llvm::BasicBlock *, llvm::Value *>
proof_event::event_prelude(
std::string const &label, llvm::Instruction *insert_before) {
auto [true_block, merge_block] = proof_branch(label, insert_before);
return {true_block, merge_block, emit_get_proof_trace_writer(true_block)};
}

llvm::BasicBlock *proof_event::check_for_emit_new_chunk(
llvm::BasicBlock *insert_at_end, llvm::BasicBlock *merge_block) {
auto *f = insert_at_end->getParent();
Expand Down Expand Up @@ -745,29 +688,44 @@ llvm::BasicBlock *proof_event::pattern_matching_failure(
return merge_block;
}

llvm::BasicBlock *proof_event::function_exit(
uint64_t ordinal, bool is_tail, llvm::Instruction *insert_before,
llvm::BasicBlock *current_block) {
//===----------------------------------------------------------------------===//
// Method template specializations
//===----------------------------------------------------------------------===//

if (!proof_hint_instrumentation) {
return current_block;
}
template <>
llvm::Function *kllvm::proof_event::get_parent_function<llvm::Instruction>(
llvm::Instruction *loc) {
return loc->getParent()->getParent();
}

std::tuple<llvm::BasicBlock *, llvm::BasicBlock *, llvm::Value *> prelude;
if (is_tail) {
assert(insert_before);
prelude = event_prelude("function_exit", insert_before);
} else {
prelude = event_prelude("function_exit", current_block);
}
template <>
llvm::Function *kllvm::proof_event::get_parent_function<llvm::BasicBlock>(
llvm::BasicBlock *loc) {
return loc->getParent();
}

auto [true_block, merge_block, proof_writer] = prelude;
template <>
llvm::BasicBlock *kllvm::proof_event::get_parent_block<llvm::Instruction>(
llvm::Instruction *loc) {
return loc->getParent();
}

emit_write_function_exit(proof_writer, ordinal, is_tail, true_block);
template <>
llvm::BasicBlock *
kllvm::proof_event::get_parent_block<llvm::BasicBlock>(llvm::BasicBlock *loc) {
return loc;
}

llvm::BranchInst::Create(merge_block, true_block);
template <>
void kllvm::proof_event::fix_insert_loc<llvm::Instruction>(
llvm::Instruction *loc, llvm::BasicBlock *merge_block) {
loc->moveBefore(*merge_block, merge_block->begin());
}

return merge_block;
template <>
void kllvm::proof_event::fix_insert_loc<llvm::BasicBlock>(
llvm::BasicBlock *loc, llvm::BasicBlock *merge_block) {
emit_no_op(merge_block);
}

} // namespace kllvm

0 comments on commit 1d0356a

Please sign in to comment.