diff --git a/include/kllvm/codegen/ProofEvent.h b/include/kllvm/codegen/ProofEvent.h index d387b4854..2aa7b8468 100644 --- a/include/kllvm/codegen/ProofEvent.h +++ b/include/kllvm/codegen/ProofEvent.h @@ -4,10 +4,13 @@ #include "kllvm/ast/AST.h" #include "kllvm/codegen/Decision.h" #include "kllvm/codegen/DecisionParser.h" +#include "kllvm/codegen/Options.h" #include "kllvm/codegen/Util.h" #include "llvm/IR/Instructions.h" +#include + #include #include @@ -21,31 +24,58 @@ class proof_event { /* * Load the boolean flag that controls whether proof hint output is enabled or - * not, then create a branch at the end of this basic block depending on the - * result. + * not, then create a branch at the specified location depending on the + * result. The location can be before a given instruction or at the end of a + * given basic block. * * Returns a pair of blocks [proof enabled, merge]; the first of these is * intended for self-contained behaviour only relevant in proof output mode, * while the second is for the continuation of the interpreter's previous * behaviour. */ + template std::pair - proof_branch(std::string const &label, llvm::BasicBlock *insert_at_end); - std::pair - proof_branch(std::string const &label, llvm::Instruction *insert_before); + proof_branch(std::string const &label, Location *insert_loc); + + /* + * Return the parent function of the given location. + + * Template specializations for llvm::Instruction and llvm::BasicBlock. + */ + template + llvm::Function *get_parent_function(Location *loc); + + /* + * Return the parent basic block of the given location. + + * Template specializations for llvm::Instruction and llvm::BasicBlock. + */ + template + llvm::BasicBlock *get_parent_block(Location *loc); + + /* + * If the given location is an Instruction, this method moves the instruction + * to the merge block. + * If the given location is a BasicBlock, this method simply emits a no-op + * instruction to the merge block. + + * Template specializations for llvm::Instruction and llvm::BasicBlock. + */ + template + void fix_insert_loc(Location *loc, llvm::BasicBlock *merge_block); /* * Set up a standard event prelude by creating a pair of basic blocks for the * proof output and continuation, then loading the output filename from its - * global. + * global. The location for the prelude can be before a given instruction or + * at the end of a given basic block. * * Returns a triple [proof enabled, merge, proof_writer]; see `proofBranch` * and `emitGetOutputFileName`. */ + template std::tuple - event_prelude(std::string const &label, llvm::BasicBlock *insert_at_end); - std::tuple - event_prelude(std::string const &label, llvm::Instruction *insert_before); + event_prelude(std::string const &label, Location *insert_loc); /* * Set up a check of whether a new proof hint chunk should be started. The @@ -239,9 +269,9 @@ class proof_event { [[nodiscard]] llvm::BasicBlock *pattern_matching_failure( kore_composite_pattern const &pattern, llvm::BasicBlock *current_block); - [[nodiscard]] llvm::BasicBlock *function_exit( - uint64_t ordinal, bool is_tail, llvm::Instruction *insert_before, - llvm::BasicBlock *current_block); + template + [[nodiscard]] llvm::BasicBlock * + function_exit(uint64_t ordinal, bool is_tail, Location *insert_loc); proof_event(kore_definition *definition, llvm::Module *module) : definition_(definition) @@ -251,4 +281,57 @@ class proof_event { } // namespace kllvm +//===----------------------------------------------------------------------===// +// Implementation for method templates +//===----------------------------------------------------------------------===// + +template +std::pair +kllvm::proof_event::proof_branch( + std::string const &label, Location *insert_loc) { + auto *i1_ty = llvm::Type::getInt1Ty(ctx_); + + auto *proof_output_flag = module_->getOrInsertGlobal("proof_output", i1_ty); + auto *proof_output = new llvm::LoadInst( + i1_ty, proof_output_flag, "proof_output", insert_loc); + + auto *f = get_parent_function(insert_loc); + auto *true_block + = llvm::BasicBlock::Create(ctx_, fmt::format("if_{}", label), f); + auto *merge_block + = llvm::BasicBlock::Create(ctx_, fmt::format("tail_{}", label), f); + + llvm::BranchInst::Create(true_block, merge_block, proof_output, insert_loc); + + fix_insert_loc(insert_loc, merge_block); + + return {true_block, merge_block}; +} + +template +std::tuple +kllvm::proof_event::event_prelude( + std::string const &label, Location *insert_loc) { + auto [true_block, merge_block] = proof_branch(label, insert_loc); + return {true_block, merge_block, emit_get_proof_trace_writer(true_block)}; +} + +template +llvm::BasicBlock *kllvm::proof_event::function_exit( + uint64_t ordinal, bool is_tail, Location *insert_loc) { + + if (!proof_hint_instrumentation) { + return get_parent_block(insert_loc); + } + + auto [true_block, merge_block, proof_writer] + = event_prelude("function_exit", insert_loc); + + emit_write_function_exit(proof_writer, ordinal, is_tail, true_block); + + llvm::BranchInst::Create(merge_block, true_block); + + return merge_block; +} + #endif // PROOF_EVENT_H diff --git a/lib/codegen/CreateTerm.cpp b/lib/codegen/CreateTerm.cpp index c1856b5da..1e03690cf 100644 --- a/lib/codegen/CreateTerm.cpp +++ b/lib/codegen/CreateTerm.cpp @@ -1295,20 +1295,19 @@ bool make_function( if (is_apply_rule) { current_block = proof_event(definition, module) - .function_exit(ordinal, true, call, current_block); + .function_exit( + ordinal, true, llvm::dyn_cast(call)); } } else { if (is_apply_rule) { - current_block - = proof_event(definition, module) - .function_exit(ordinal, false, nullptr, current_block); + current_block = proof_event(definition, module) + .function_exit(ordinal, false, current_block); } } } else { if (is_apply_rule) { - current_block - = proof_event(definition, module) - .function_exit(ordinal, false, nullptr, current_block); + current_block = proof_event(definition, module) + .function_exit(ordinal, false, current_block); } } } diff --git a/lib/codegen/ProofEvent.cpp b/lib/codegen/ProofEvent.cpp index 3ed817730..62fafabae 100644 --- a/lib/codegen/ProofEvent.cpp +++ b/lib/codegen/ProofEvent.cpp @@ -372,63 +372,6 @@ proof_event::emit_get_proof_chunk_size(llvm::BasicBlock *insert_at_end) { i64_ty, proof_chunk_size_pointer, "proof_chunk_size", insert_at_end); } -std::pair proof_event::proof_branch( - std::string const &label, llvm::BasicBlock *insert_at_end) { - auto *i1_ty = llvm::Type::getInt1Ty(ctx_); - - auto *proof_output_flag = module_->getOrInsertGlobal("proof_output", i1_ty); - auto *proof_output = new llvm::LoadInst( - i1_ty, proof_output_flag, "proof_output", insert_at_end); - - auto *f = insert_at_end->getParent(); - auto *true_block - = llvm::BasicBlock::Create(ctx_, fmt::format("if_{}", label), f); - auto *merge_block - = llvm::BasicBlock::Create(ctx_, fmt::format("tail_{}", label), f); - - emit_no_op(merge_block); - - llvm::BranchInst::Create( - true_block, merge_block, proof_output, insert_at_end); - return {true_block, merge_block}; -} - -std::pair proof_event::proof_branch( - std::string const &label, llvm::Instruction *insert_before) { - auto *i1_ty = llvm::Type::getInt1Ty(ctx_); - - auto *proof_output_flag = module_->getOrInsertGlobal("proof_output", i1_ty); - auto *proof_output = new llvm::LoadInst( - i1_ty, proof_output_flag, "proof_output", insert_before); - - auto *f = insert_before->getParent()->getParent(); - auto *true_block - = llvm::BasicBlock::Create(ctx_, fmt::format("if_{}", label), f); - auto *merge_block - = llvm::BasicBlock::Create(ctx_, fmt::format("tail_{}", label), f); - - llvm::BranchInst::Create( - true_block, merge_block, proof_output, insert_before); - - insert_before->moveBefore(*merge_block, merge_block->begin()); - - return {true_block, merge_block}; -} - -std::tuple -proof_event::event_prelude( - std::string const &label, llvm::BasicBlock *insert_at_end) { - auto [true_block, merge_block] = proof_branch(label, insert_at_end); - return {true_block, merge_block, emit_get_proof_trace_writer(true_block)}; -} - -std::tuple -proof_event::event_prelude( - std::string const &label, llvm::Instruction *insert_before) { - auto [true_block, merge_block] = proof_branch(label, insert_before); - return {true_block, merge_block, emit_get_proof_trace_writer(true_block)}; -} - llvm::BasicBlock *proof_event::check_for_emit_new_chunk( llvm::BasicBlock *insert_at_end, llvm::BasicBlock *merge_block) { auto *f = insert_at_end->getParent(); @@ -745,29 +688,44 @@ llvm::BasicBlock *proof_event::pattern_matching_failure( return merge_block; } -llvm::BasicBlock *proof_event::function_exit( - uint64_t ordinal, bool is_tail, llvm::Instruction *insert_before, - llvm::BasicBlock *current_block) { +//===----------------------------------------------------------------------===// +// Method template specializations +//===----------------------------------------------------------------------===// - if (!proof_hint_instrumentation) { - return current_block; - } +template <> +llvm::Function *kllvm::proof_event::get_parent_function( + llvm::Instruction *loc) { + return loc->getParent()->getParent(); +} - std::tuple prelude; - if (is_tail) { - assert(insert_before); - prelude = event_prelude("function_exit", insert_before); - } else { - prelude = event_prelude("function_exit", current_block); - } +template <> +llvm::Function *kllvm::proof_event::get_parent_function( + llvm::BasicBlock *loc) { + return loc->getParent(); +} - auto [true_block, merge_block, proof_writer] = prelude; +template <> +llvm::BasicBlock *kllvm::proof_event::get_parent_block( + llvm::Instruction *loc) { + return loc->getParent(); +} - emit_write_function_exit(proof_writer, ordinal, is_tail, true_block); +template <> +llvm::BasicBlock * +kllvm::proof_event::get_parent_block(llvm::BasicBlock *loc) { + return loc; +} - llvm::BranchInst::Create(merge_block, true_block); +template <> +void kllvm::proof_event::fix_insert_loc( + llvm::Instruction *loc, llvm::BasicBlock *merge_block) { + loc->moveBefore(*merge_block, merge_block->begin()); +} - return merge_block; +template <> +void kllvm::proof_event::fix_insert_loc( + llvm::BasicBlock *loc, llvm::BasicBlock *merge_block) { + emit_no_op(merge_block); } } // namespace kllvm