Skip to content

Commit

Permalink
Merge branch 'master' into evaluate-2
Browse files Browse the repository at this point in the history
  • Loading branch information
Baltoli committed Dec 4, 2023
2 parents bf43a91 + c75fe25 commit f6b8abf
Show file tree
Hide file tree
Showing 23 changed files with 6,615 additions and 490 deletions.
20 changes: 5 additions & 15 deletions bindings/c/lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ namespace fs = std::filesystem;
// Internal implementation details
namespace {

template <typename OS>
char *get_c_string(OS const &);
char *get_c_string(std::string const &);

kore_pattern *kore_string_pattern_new_internal(std::string const &);

Expand Down Expand Up @@ -85,9 +84,7 @@ struct kore_symbol {
/* KOREPattern */

char *kore_pattern_dump(kore_pattern const *pat) {
auto os = std::ostringstream{};
pat->ptr_->print(os);
return get_c_string(os);
return get_c_string(ast_to_string(*pat->ptr_));
}

char *kore_pattern_pretty_print(kore_pattern const *pat) {
Expand Down Expand Up @@ -320,9 +317,7 @@ kore_string_pattern_new_with_len(char const *contents, size_t len) {
/* KORESort */

char *kore_sort_dump(kore_sort const *sort) {
auto os = std::ostringstream{};
sort->ptr_->print(os);
return get_c_string(os);
return get_c_string(ast_to_string(*sort->ptr_));
}

void kore_sort_free(kore_sort const *sort) {
Expand Down Expand Up @@ -372,9 +367,7 @@ void kore_symbol_free(kore_symbol const *sym) {
}

char *kore_symbol_dump(kore_symbol const *sym) {
auto os = std::ostringstream{};
sym->ptr_->print(os);
return get_c_string(os);
return get_c_string(ast_to_string(*sym->ptr_));
}

void kore_symbol_add_formal_argument(kore_symbol *sym, kore_sort const *sort) {
Expand All @@ -394,10 +387,7 @@ void kllvm_free_all_memory(void) {

namespace {

template <typename OS>
char *get_c_string(OS const &os) {
auto str = os.str();

char *get_c_string(std::string const &str) {
// Include null terminator
auto total_length = str.length() + 1;

Expand Down
5 changes: 5 additions & 0 deletions bindings/core/src/core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ void *constructInitialConfiguration(const KOREPattern *);

namespace kllvm::bindings {

std::string return_sort_for_label(std::string const &label) {
auto tag = getTagForSymbolName(label.c_str());
return getReturnSortForTag(tag);
}

std::shared_ptr<KOREPattern> make_injection(
std::shared_ptr<KOREPattern> term, std::shared_ptr<KORESort> from,
std::shared_ptr<KORESort> to) {
Expand Down
12 changes: 4 additions & 8 deletions bindings/python/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,18 +48,14 @@ void *constructInitialConfiguration(const KOREPattern *initial);
void bind_runtime(py::module_ &m) {
auto runtime = m.def_submodule("runtime", "K LLVM backend runtime");

// These simplifications should really be a member function on the Python
// These simplifications should really be member functions on the Python
// Pattern class, but they depend on the runtime library and so need to be
// bound as free functions in the kllvm.runtime module.

m.def(
"simplify_pattern",
[](std::shared_ptr<KOREPattern> pattern, std::shared_ptr<KORESort> sort) {
return bindings::simplify(pattern, sort);
});

m.def("simplify_pattern", bindings::simplify);
m.def("simplify_bool_pattern", bindings::simplify_to_bool);

m.def("return_sort_for_label", bindings::return_sort_for_label);

m.def("evaluate_function", bindings::evaluate_function);

// This class can't be used directly from Python; the mutability semantics
Expand Down
29 changes: 16 additions & 13 deletions include/kllvm/ast/AST.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,18 @@ using sptr = std::shared_ptr<T>;

std::string decodeKore(std::string);

/*
* Helper function to avoid repeated call-site uses of ostringstream when we
* just want the string representation of a node, rather than to print it to a
* stream.
*/
template <typename T, typename... Args>
std::string ast_to_string(T &&node, Args &&...args) {
auto os = std::ostringstream{};
std::forward<T>(node).print(os, std::forward<Args>(args)...);
return os.str();
}

// KORESort
class KORESort : public std::enable_shared_from_this<KORESort> {
public:
Expand All @@ -62,9 +74,7 @@ static inline std::ostream &operator<<(std::ostream &out, const KORESort &s) {

struct HashSort {
size_t operator()(const kllvm::KORESort &s) const noexcept {
std::ostringstream Out;
s.print(Out);
return std::hash<std::string>{}(Out.str());
return std::hash<std::string>{}(ast_to_string(s));
}
};

Expand All @@ -76,9 +86,7 @@ struct EqualSortPtr {

struct HashSortPtr {
size_t operator()(kllvm::KORESort *const &s) const noexcept {
std::ostringstream Out;
s->print(Out);
return std::hash<std::string>{}(Out.str());
return std::hash<std::string>{}(ast_to_string(*s));
}
};

Expand Down Expand Up @@ -281,18 +289,13 @@ struct HashSymbol {

struct EqualSymbolPtr {
bool operator()(KORESymbol *const &first, KORESymbol *const &second) const {
std::ostringstream Out1, Out2;
first->print(Out1);
second->print(Out2);
return Out1.str() == Out2.str();
return ast_to_string(*first) == ast_to_string(*second);
}
};

struct HashSymbolPtr {
size_t operator()(kllvm::KORESymbol *const &s) const noexcept {
std::ostringstream Out;
s->print(Out);
return std::hash<std::string>{}(Out.str());
return std::hash<std::string>{}(ast_to_string(*s));
}
};

Expand Down
2 changes: 2 additions & 0 deletions include/kllvm/bindings/core/core.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

namespace kllvm::bindings {

std::string return_sort_for_label(std::string const &label);

std::shared_ptr<kllvm::KOREPattern> make_injection(
std::shared_ptr<kllvm::KOREPattern> term,
std::shared_ptr<kllvm::KORESort> from, std::shared_ptr<kllvm::KORESort> to);
Expand Down
1 change: 1 addition & 0 deletions include/runtime/header.h
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ uint32_t getInjectionForSortOfTag(uint32_t tag);
bool hook_STRING_eq(SortString, SortString);

const char *getSymbolNameForTag(uint32_t tag);
const char *getReturnSortForTag(uint32_t tag);
const char *topSort(void);

typedef struct {
Expand Down
13 changes: 3 additions & 10 deletions lib/ast/AST.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -970,9 +970,7 @@ sptr<KOREPattern> KORECompositePattern::dedupeDisjuncts(void) {
flatten(this, "\\or", items);
std::set<std::string> printed;
for (sptr<KOREPattern> item : items) {
std::ostringstream Out;
item->print(Out);
if (printed.insert(Out.str()).second) {
if (printed.insert(ast_to_string(*item)).second) {
dedupedItems.push_back(item);
}
}
Expand Down Expand Up @@ -1170,10 +1168,7 @@ bool KOREVariablePattern::matches(
substitution &subst, SubsortMap const &subsorts, SymbolMap const &overloads,
sptr<KOREPattern> subject) {
if (subst[name->getName()]) {
std::ostringstream Out1, Out2;
subst[name->getName()]->print(Out1);
subject->print(Out2);
return Out1.str() == Out2.str();
return ast_to_string(*subst[name->getName()]) == ast_to_string(*subject);
} else {
subst[name->getName()] = subject;
return true;
Expand Down Expand Up @@ -1796,9 +1791,7 @@ void KOREDefinition::preprocess() {
symbol->firstTag = symbol->lastTag = instantiations.at(*symbol);
symbol->layout = layouts.at(layoutStr);
objectSymbols[symbol->firstTag] = symbol;
std::ostringstream Out;
symbol->print(Out);
allObjectSymbols[Out.str()] = symbol;
allObjectSymbols[ast_to_string(*symbol)] = symbol;
}
}
uint32_t lastTag = nextSymbol - 1;
Expand Down
26 changes: 10 additions & 16 deletions lib/codegen/CreateTerm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -652,10 +652,9 @@ llvm::Value *CreateTerm::createHook(
std::string domain = name.substr(0, name.find('.'));
if (domain == "ARRAY") {
// array is not really hooked in llvm, it's implemented in K
std::ostringstream Out;
pattern->getConstructor()->print(Out, 0, false);
return createFunctionCall(
"eval_" + Out.str(), pattern, false, true, locationStack);
auto fn_name = fmt::format(
"eval_{}", ast_to_string(*pattern->getConstructor(), 0, false));
return createFunctionCall(fn_name, pattern, false, true, locationStack);
}
std::string hookName
= "hook_" + domain + "_" + name.substr(name.find('.') + 1);
Expand Down Expand Up @@ -900,11 +899,10 @@ CreateTerm::createAllocation(KOREPattern *pattern, std::string locationStack) {

return std::make_pair(val, true);
} else {
std::ostringstream Out;
symbol->print(Out, 0, false);
auto fn_name = fmt::format("eval_{}", ast_to_string(*symbol, 0, false));
return std::make_pair(
createFunctionCall(
"eval_" + Out.str(), constructor, false, true, locationStack),
fn_name, constructor, false, true, locationStack),
true);
}
} else if (auto cat = dynamic_cast<KORECompositeSort *>(
Expand Down Expand Up @@ -1008,10 +1006,8 @@ bool makeFunction(
return false;
}
auto cat = sort->getCategory(definition);
std::ostringstream Out;
sort->print(Out);
llvm::Type *paramType = getValueType(cat, Module);
debugArgs.push_back(getDebugType(cat, Out.str()));
debugArgs.push_back(getDebugType(cat, ast_to_string(*sort)));
switch (cat.cat) {
case SortCategory::Map:
case SortCategory::RangeMap:
Expand Down Expand Up @@ -1045,11 +1041,11 @@ bool makeFunction(
if (axiom->getAttributes().count("label")) {
debugName = axiom->getStringAttribute("label") + postfix;
}
std::ostringstream Out;
termSort(pattern)->print(Out);
initDebugFunction(
debugName, debugName,
getDebugFunctionType(getDebugType(returnCat, Out.str()), debugArgs),
getDebugFunctionType(
getDebugType(returnCat, ast_to_string(*termSort(pattern))),
debugArgs),
definition, applyRule);
if (tailcc) {
applyRule->setCallingConv(llvm::CallingConv::Tail);
Expand Down Expand Up @@ -1129,10 +1125,8 @@ std::string makeApplyRuleFunction(
return "";
}
auto cat = sort->getCategory(definition);
std::ostringstream Out;
sort->print(Out);
llvm::Type *paramType = getValueType(cat, Module);
debugArgs.push_back(getDebugType(cat, Out.str()));
debugArgs.push_back(getDebugType(cat, ast_to_string(*sort)));
switch (cat.cat) {
case SortCategory::Map:
case SortCategory::RangeMap:
Expand Down
54 changes: 22 additions & 32 deletions lib/codegen/Decision.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,14 @@
#include <llvm/IR/Value.h>
#include <llvm/Support/Casting.h>

#include <fmt/format.h>

#include <iostream>
#include <limits>
#include <memory>
#include <set>
#include <type_traits>

namespace kllvm {

static std::string LAYOUTITEM_STRUCT = "layoutitem";
Expand Down Expand Up @@ -104,22 +107,19 @@ getFailPattern(DecisionCase const &_case, bool isInt) {
+ std::to_string(bitwidth) + "\")");
}
} else {
std::ostringstream symbol;
_case.getConstructor()->print(symbol);
std::ostringstream returnSort;
_case.getConstructor()->getSort()->print(returnSort);
std::string result = symbol.str() + "(";
auto result = fmt::format("{}(", ast_to_string(*_case.getConstructor()));

std::string conn = "";
for (int i = 0; i < _case.getConstructor()->getArguments().size(); i++) {
result += conn;
result += "Var'Unds'";
std::ostringstream argSort;
_case.getConstructor()->getArguments()[i]->print(argSort);
result += ":" + argSort.str();
result += fmt::format(
"{}Var'Unds':{}", conn,
ast_to_string(*_case.getConstructor()->getArguments()[i]));
conn = ",";
}
result += ")";
return std::make_pair(returnSort.str(), result);

auto return_sort = ast_to_string(*_case.getConstructor()->getSort());
return std::make_pair(return_sort, result);
}
}

Expand Down Expand Up @@ -732,18 +732,15 @@ void makeEvalOrAnywhereFunction(
auto returnSort = dynamic_cast<KORECompositeSort *>(function->getSort().get())
->getCategory(definition);
auto returnType = getParamType(returnSort, module);
std::ostringstream Out;
function->getSort()->print(Out);
auto debugReturnType = getDebugType(returnSort, Out.str());
auto debugReturnType
= getDebugType(returnSort, ast_to_string(*function->getSort()));
std::vector<llvm::Type *> args;
std::vector<llvm::Metadata *> debugArgs;
std::vector<ValueType> cats;
for (auto &sort : function->getArguments()) {
auto cat = dynamic_cast<KORECompositeSort *>(sort.get())
->getCategory(definition);
std::ostringstream Out;
sort->print(Out);
debugArgs.push_back(getDebugType(cat, Out.str()));
debugArgs.push_back(getDebugType(cat, ast_to_string(*sort)));
switch (cat.cat) {
case SortCategory::Map:
case SortCategory::RangeMap:
Expand All @@ -760,9 +757,7 @@ void makeEvalOrAnywhereFunction(
}
llvm::FunctionType *funcType
= llvm::FunctionType::get(returnType, args, false);
std::ostringstream Out2;
function->print(Out2, 0, false);
std::string name = "eval_" + Out2.str();
std::string name = fmt::format("eval_{}", ast_to_string(*function, 0, false));
llvm::Function *matchFunc = getOrInsertFunction(module, name, funcType);
KORESymbolDeclaration *symbolDecl
= definition->getSymbolDeclarations().at(function->getName());
Expand Down Expand Up @@ -791,9 +786,9 @@ void makeEvalOrAnywhereFunction(
++val, ++i) {
val->setName("_" + std::to_string(i + 1));
codegen.store(std::make_pair(val->getName().str(), val->getType()), val);
std::ostringstream Out;
function->getArguments()[i]->print(Out);
initDebugParam(matchFunc, i, val->getName().str(), cats[i], Out.str());
initDebugParam(
matchFunc, i, val->getName().str(), cats[i],
ast_to_string(*function->getArguments()[i]));
}
addStuck(stuck, module, function, codegen, definition);

Expand All @@ -804,9 +799,7 @@ void abortWhenStuck(
llvm::BasicBlock *CurrentBlock, llvm::Module *Module, KORESymbol *symbol,
Decision &codegen, KOREDefinition *d) {
auto &Ctx = Module->getContext();
std::ostringstream Out;
symbol->print(Out);
symbol = d->getAllSymbols().at(Out.str());
symbol = d->getAllSymbols().at(ast_to_string(*symbol));
auto BlockType = getBlockType(Module, d, symbol);
llvm::Value *Ptr;
auto BlockPtr = llvm::PointerType::getUnqual(
Expand Down Expand Up @@ -1276,9 +1269,7 @@ void makeStepFunction(
auto argSort
= dynamic_cast<KORECompositeSort *>(res.pattern->getSort().get());
auto cat = argSort->getCategory(definition);
std::ostringstream Out;
argSort->print(Out);
debugTypes.push_back(getDebugType(cat, Out.str()));
debugTypes.push_back(getDebugType(cat, ast_to_string(*argSort)));
switch (cat.cat) {
case SortCategory::Map:
case SortCategory::RangeMap:
Expand Down Expand Up @@ -1334,9 +1325,8 @@ void makeStepFunction(
auto cat = dynamic_cast<KORECompositeSort *>(sort.get())
->getCategory(definition);
types.push_back(cat);
std::ostringstream Out;
sort->print(Out);
initDebugParam(matchFunc, i, "_" + std::to_string(i + 1), cat, Out.str());
initDebugParam(
matchFunc, i, "_" + std::to_string(i + 1), cat, ast_to_string(*sort));
}
auto header = stepFunctionHeader(
axiom->getOrdinal(), module, definition, block, stuck, args, types);
Expand Down
Loading

0 comments on commit f6b8abf

Please sign in to comment.