Skip to content

Commit f6b8abf

Browse files
committed
Merge branch 'master' into evaluate-2
2 parents bf43a91 + c75fe25 commit f6b8abf

File tree

23 files changed

+6615
-490
lines changed

23 files changed

+6615
-490
lines changed

bindings/c/lib.cpp

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,7 @@ namespace fs = std::filesystem;
2525
// Internal implementation details
2626
namespace {
2727

28-
template <typename OS>
29-
char *get_c_string(OS const &);
28+
char *get_c_string(std::string const &);
3029

3130
kore_pattern *kore_string_pattern_new_internal(std::string const &);
3231

@@ -85,9 +84,7 @@ struct kore_symbol {
8584
/* KOREPattern */
8685

8786
char *kore_pattern_dump(kore_pattern const *pat) {
88-
auto os = std::ostringstream{};
89-
pat->ptr_->print(os);
90-
return get_c_string(os);
87+
return get_c_string(ast_to_string(*pat->ptr_));
9188
}
9289

9390
char *kore_pattern_pretty_print(kore_pattern const *pat) {
@@ -320,9 +317,7 @@ kore_string_pattern_new_with_len(char const *contents, size_t len) {
320317
/* KORESort */
321318

322319
char *kore_sort_dump(kore_sort const *sort) {
323-
auto os = std::ostringstream{};
324-
sort->ptr_->print(os);
325-
return get_c_string(os);
320+
return get_c_string(ast_to_string(*sort->ptr_));
326321
}
327322

328323
void kore_sort_free(kore_sort const *sort) {
@@ -372,9 +367,7 @@ void kore_symbol_free(kore_symbol const *sym) {
372367
}
373368

374369
char *kore_symbol_dump(kore_symbol const *sym) {
375-
auto os = std::ostringstream{};
376-
sym->ptr_->print(os);
377-
return get_c_string(os);
370+
return get_c_string(ast_to_string(*sym->ptr_));
378371
}
379372

380373
void kore_symbol_add_formal_argument(kore_symbol *sym, kore_sort const *sort) {
@@ -394,10 +387,7 @@ void kllvm_free_all_memory(void) {
394387

395388
namespace {
396389

397-
template <typename OS>
398-
char *get_c_string(OS const &os) {
399-
auto str = os.str();
400-
390+
char *get_c_string(std::string const &str) {
401391
// Include null terminator
402392
auto total_length = str.length() + 1;
403393

bindings/core/src/core.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@ void *constructInitialConfiguration(const KOREPattern *);
1212

1313
namespace kllvm::bindings {
1414

15+
std::string return_sort_for_label(std::string const &label) {
16+
auto tag = getTagForSymbolName(label.c_str());
17+
return getReturnSortForTag(tag);
18+
}
19+
1520
std::shared_ptr<KOREPattern> make_injection(
1621
std::shared_ptr<KOREPattern> term, std::shared_ptr<KORESort> from,
1722
std::shared_ptr<KORESort> to) {

bindings/python/runtime.cpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,18 +48,14 @@ void *constructInitialConfiguration(const KOREPattern *initial);
4848
void bind_runtime(py::module_ &m) {
4949
auto runtime = m.def_submodule("runtime", "K LLVM backend runtime");
5050

51-
// These simplifications should really be a member function on the Python
51+
// These simplifications should really be member functions on the Python
5252
// Pattern class, but they depend on the runtime library and so need to be
5353
// bound as free functions in the kllvm.runtime module.
54-
55-
m.def(
56-
"simplify_pattern",
57-
[](std::shared_ptr<KOREPattern> pattern, std::shared_ptr<KORESort> sort) {
58-
return bindings::simplify(pattern, sort);
59-
});
60-
54+
m.def("simplify_pattern", bindings::simplify);
6155
m.def("simplify_bool_pattern", bindings::simplify_to_bool);
6256

57+
m.def("return_sort_for_label", bindings::return_sort_for_label);
58+
6359
m.def("evaluate_function", bindings::evaluate_function);
6460

6561
// This class can't be used directly from Python; the mutability semantics

include/kllvm/ast/AST.h

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,18 @@ using sptr = std::shared_ptr<T>;
3636

3737
std::string decodeKore(std::string);
3838

39+
/*
40+
* Helper function to avoid repeated call-site uses of ostringstream when we
41+
* just want the string representation of a node, rather than to print it to a
42+
* stream.
43+
*/
44+
template <typename T, typename... Args>
45+
std::string ast_to_string(T &&node, Args &&...args) {
46+
auto os = std::ostringstream{};
47+
std::forward<T>(node).print(os, std::forward<Args>(args)...);
48+
return os.str();
49+
}
50+
3951
// KORESort
4052
class KORESort : public std::enable_shared_from_this<KORESort> {
4153
public:
@@ -62,9 +74,7 @@ static inline std::ostream &operator<<(std::ostream &out, const KORESort &s) {
6274

6375
struct HashSort {
6476
size_t operator()(const kllvm::KORESort &s) const noexcept {
65-
std::ostringstream Out;
66-
s.print(Out);
67-
return std::hash<std::string>{}(Out.str());
77+
return std::hash<std::string>{}(ast_to_string(s));
6878
}
6979
};
7080

@@ -76,9 +86,7 @@ struct EqualSortPtr {
7686

7787
struct HashSortPtr {
7888
size_t operator()(kllvm::KORESort *const &s) const noexcept {
79-
std::ostringstream Out;
80-
s->print(Out);
81-
return std::hash<std::string>{}(Out.str());
89+
return std::hash<std::string>{}(ast_to_string(*s));
8290
}
8391
};
8492

@@ -281,18 +289,13 @@ struct HashSymbol {
281289

282290
struct EqualSymbolPtr {
283291
bool operator()(KORESymbol *const &first, KORESymbol *const &second) const {
284-
std::ostringstream Out1, Out2;
285-
first->print(Out1);
286-
second->print(Out2);
287-
return Out1.str() == Out2.str();
292+
return ast_to_string(*first) == ast_to_string(*second);
288293
}
289294
};
290295

291296
struct HashSymbolPtr {
292297
size_t operator()(kllvm::KORESymbol *const &s) const noexcept {
293-
std::ostringstream Out;
294-
s->print(Out);
295-
return std::hash<std::string>{}(Out.str());
298+
return std::hash<std::string>{}(ast_to_string(*s));
296299
}
297300
};
298301

include/kllvm/bindings/core/core.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
namespace kllvm::bindings {
1414

15+
std::string return_sort_for_label(std::string const &label);
16+
1517
std::shared_ptr<kllvm::KOREPattern> make_injection(
1618
std::shared_ptr<kllvm::KOREPattern> term,
1719
std::shared_ptr<kllvm::KORESort> from, std::shared_ptr<kllvm::KORESort> to);

include/runtime/header.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,7 @@ uint32_t getInjectionForSortOfTag(uint32_t tag);
350350
bool hook_STRING_eq(SortString, SortString);
351351

352352
const char *getSymbolNameForTag(uint32_t tag);
353+
const char *getReturnSortForTag(uint32_t tag);
353354
const char *topSort(void);
354355

355356
typedef struct {

lib/ast/AST.cpp

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -970,9 +970,7 @@ sptr<KOREPattern> KORECompositePattern::dedupeDisjuncts(void) {
970970
flatten(this, "\\or", items);
971971
std::set<std::string> printed;
972972
for (sptr<KOREPattern> item : items) {
973-
std::ostringstream Out;
974-
item->print(Out);
975-
if (printed.insert(Out.str()).second) {
973+
if (printed.insert(ast_to_string(*item)).second) {
976974
dedupedItems.push_back(item);
977975
}
978976
}
@@ -1170,10 +1168,7 @@ bool KOREVariablePattern::matches(
11701168
substitution &subst, SubsortMap const &subsorts, SymbolMap const &overloads,
11711169
sptr<KOREPattern> subject) {
11721170
if (subst[name->getName()]) {
1173-
std::ostringstream Out1, Out2;
1174-
subst[name->getName()]->print(Out1);
1175-
subject->print(Out2);
1176-
return Out1.str() == Out2.str();
1171+
return ast_to_string(*subst[name->getName()]) == ast_to_string(*subject);
11771172
} else {
11781173
subst[name->getName()] = subject;
11791174
return true;
@@ -1796,9 +1791,7 @@ void KOREDefinition::preprocess() {
17961791
symbol->firstTag = symbol->lastTag = instantiations.at(*symbol);
17971792
symbol->layout = layouts.at(layoutStr);
17981793
objectSymbols[symbol->firstTag] = symbol;
1799-
std::ostringstream Out;
1800-
symbol->print(Out);
1801-
allObjectSymbols[Out.str()] = symbol;
1794+
allObjectSymbols[ast_to_string(*symbol)] = symbol;
18021795
}
18031796
}
18041797
uint32_t lastTag = nextSymbol - 1;

lib/codegen/CreateTerm.cpp

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -652,10 +652,9 @@ llvm::Value *CreateTerm::createHook(
652652
std::string domain = name.substr(0, name.find('.'));
653653
if (domain == "ARRAY") {
654654
// array is not really hooked in llvm, it's implemented in K
655-
std::ostringstream Out;
656-
pattern->getConstructor()->print(Out, 0, false);
657-
return createFunctionCall(
658-
"eval_" + Out.str(), pattern, false, true, locationStack);
655+
auto fn_name = fmt::format(
656+
"eval_{}", ast_to_string(*pattern->getConstructor(), 0, false));
657+
return createFunctionCall(fn_name, pattern, false, true, locationStack);
659658
}
660659
std::string hookName
661660
= "hook_" + domain + "_" + name.substr(name.find('.') + 1);
@@ -900,11 +899,10 @@ CreateTerm::createAllocation(KOREPattern *pattern, std::string locationStack) {
900899

901900
return std::make_pair(val, true);
902901
} else {
903-
std::ostringstream Out;
904-
symbol->print(Out, 0, false);
902+
auto fn_name = fmt::format("eval_{}", ast_to_string(*symbol, 0, false));
905903
return std::make_pair(
906904
createFunctionCall(
907-
"eval_" + Out.str(), constructor, false, true, locationStack),
905+
fn_name, constructor, false, true, locationStack),
908906
true);
909907
}
910908
} else if (auto cat = dynamic_cast<KORECompositeSort *>(
@@ -1008,10 +1006,8 @@ bool makeFunction(
10081006
return false;
10091007
}
10101008
auto cat = sort->getCategory(definition);
1011-
std::ostringstream Out;
1012-
sort->print(Out);
10131009
llvm::Type *paramType = getValueType(cat, Module);
1014-
debugArgs.push_back(getDebugType(cat, Out.str()));
1010+
debugArgs.push_back(getDebugType(cat, ast_to_string(*sort)));
10151011
switch (cat.cat) {
10161012
case SortCategory::Map:
10171013
case SortCategory::RangeMap:
@@ -1045,11 +1041,11 @@ bool makeFunction(
10451041
if (axiom->getAttributes().count("label")) {
10461042
debugName = axiom->getStringAttribute("label") + postfix;
10471043
}
1048-
std::ostringstream Out;
1049-
termSort(pattern)->print(Out);
10501044
initDebugFunction(
10511045
debugName, debugName,
1052-
getDebugFunctionType(getDebugType(returnCat, Out.str()), debugArgs),
1046+
getDebugFunctionType(
1047+
getDebugType(returnCat, ast_to_string(*termSort(pattern))),
1048+
debugArgs),
10531049
definition, applyRule);
10541050
if (tailcc) {
10551051
applyRule->setCallingConv(llvm::CallingConv::Tail);
@@ -1129,10 +1125,8 @@ std::string makeApplyRuleFunction(
11291125
return "";
11301126
}
11311127
auto cat = sort->getCategory(definition);
1132-
std::ostringstream Out;
1133-
sort->print(Out);
11341128
llvm::Type *paramType = getValueType(cat, Module);
1135-
debugArgs.push_back(getDebugType(cat, Out.str()));
1129+
debugArgs.push_back(getDebugType(cat, ast_to_string(*sort)));
11361130
switch (cat.cat) {
11371131
case SortCategory::Map:
11381132
case SortCategory::RangeMap:

lib/codegen/Decision.cpp

Lines changed: 22 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,14 @@
2929
#include <llvm/IR/Value.h>
3030
#include <llvm/Support/Casting.h>
3131

32+
#include <fmt/format.h>
33+
3234
#include <iostream>
3335
#include <limits>
3436
#include <memory>
3537
#include <set>
3638
#include <type_traits>
39+
3740
namespace kllvm {
3841

3942
static std::string LAYOUTITEM_STRUCT = "layoutitem";
@@ -104,22 +107,19 @@ getFailPattern(DecisionCase const &_case, bool isInt) {
104107
+ std::to_string(bitwidth) + "\")");
105108
}
106109
} else {
107-
std::ostringstream symbol;
108-
_case.getConstructor()->print(symbol);
109-
std::ostringstream returnSort;
110-
_case.getConstructor()->getSort()->print(returnSort);
111-
std::string result = symbol.str() + "(";
110+
auto result = fmt::format("{}(", ast_to_string(*_case.getConstructor()));
111+
112112
std::string conn = "";
113113
for (int i = 0; i < _case.getConstructor()->getArguments().size(); i++) {
114-
result += conn;
115-
result += "Var'Unds'";
116-
std::ostringstream argSort;
117-
_case.getConstructor()->getArguments()[i]->print(argSort);
118-
result += ":" + argSort.str();
114+
result += fmt::format(
115+
"{}Var'Unds':{}", conn,
116+
ast_to_string(*_case.getConstructor()->getArguments()[i]));
119117
conn = ",";
120118
}
121119
result += ")";
122-
return std::make_pair(returnSort.str(), result);
120+
121+
auto return_sort = ast_to_string(*_case.getConstructor()->getSort());
122+
return std::make_pair(return_sort, result);
123123
}
124124
}
125125

@@ -732,18 +732,15 @@ void makeEvalOrAnywhereFunction(
732732
auto returnSort = dynamic_cast<KORECompositeSort *>(function->getSort().get())
733733
->getCategory(definition);
734734
auto returnType = getParamType(returnSort, module);
735-
std::ostringstream Out;
736-
function->getSort()->print(Out);
737-
auto debugReturnType = getDebugType(returnSort, Out.str());
735+
auto debugReturnType
736+
= getDebugType(returnSort, ast_to_string(*function->getSort()));
738737
std::vector<llvm::Type *> args;
739738
std::vector<llvm::Metadata *> debugArgs;
740739
std::vector<ValueType> cats;
741740
for (auto &sort : function->getArguments()) {
742741
auto cat = dynamic_cast<KORECompositeSort *>(sort.get())
743742
->getCategory(definition);
744-
std::ostringstream Out;
745-
sort->print(Out);
746-
debugArgs.push_back(getDebugType(cat, Out.str()));
743+
debugArgs.push_back(getDebugType(cat, ast_to_string(*sort)));
747744
switch (cat.cat) {
748745
case SortCategory::Map:
749746
case SortCategory::RangeMap:
@@ -760,9 +757,7 @@ void makeEvalOrAnywhereFunction(
760757
}
761758
llvm::FunctionType *funcType
762759
= llvm::FunctionType::get(returnType, args, false);
763-
std::ostringstream Out2;
764-
function->print(Out2, 0, false);
765-
std::string name = "eval_" + Out2.str();
760+
std::string name = fmt::format("eval_{}", ast_to_string(*function, 0, false));
766761
llvm::Function *matchFunc = getOrInsertFunction(module, name, funcType);
767762
KORESymbolDeclaration *symbolDecl
768763
= definition->getSymbolDeclarations().at(function->getName());
@@ -791,9 +786,9 @@ void makeEvalOrAnywhereFunction(
791786
++val, ++i) {
792787
val->setName("_" + std::to_string(i + 1));
793788
codegen.store(std::make_pair(val->getName().str(), val->getType()), val);
794-
std::ostringstream Out;
795-
function->getArguments()[i]->print(Out);
796-
initDebugParam(matchFunc, i, val->getName().str(), cats[i], Out.str());
789+
initDebugParam(
790+
matchFunc, i, val->getName().str(), cats[i],
791+
ast_to_string(*function->getArguments()[i]));
797792
}
798793
addStuck(stuck, module, function, codegen, definition);
799794

@@ -804,9 +799,7 @@ void abortWhenStuck(
804799
llvm::BasicBlock *CurrentBlock, llvm::Module *Module, KORESymbol *symbol,
805800
Decision &codegen, KOREDefinition *d) {
806801
auto &Ctx = Module->getContext();
807-
std::ostringstream Out;
808-
symbol->print(Out);
809-
symbol = d->getAllSymbols().at(Out.str());
802+
symbol = d->getAllSymbols().at(ast_to_string(*symbol));
810803
auto BlockType = getBlockType(Module, d, symbol);
811804
llvm::Value *Ptr;
812805
auto BlockPtr = llvm::PointerType::getUnqual(
@@ -1276,9 +1269,7 @@ void makeStepFunction(
12761269
auto argSort
12771270
= dynamic_cast<KORECompositeSort *>(res.pattern->getSort().get());
12781271
auto cat = argSort->getCategory(definition);
1279-
std::ostringstream Out;
1280-
argSort->print(Out);
1281-
debugTypes.push_back(getDebugType(cat, Out.str()));
1272+
debugTypes.push_back(getDebugType(cat, ast_to_string(*argSort)));
12821273
switch (cat.cat) {
12831274
case SortCategory::Map:
12841275
case SortCategory::RangeMap:
@@ -1334,9 +1325,8 @@ void makeStepFunction(
13341325
auto cat = dynamic_cast<KORECompositeSort *>(sort.get())
13351326
->getCategory(definition);
13361327
types.push_back(cat);
1337-
std::ostringstream Out;
1338-
sort->print(Out);
1339-
initDebugParam(matchFunc, i, "_" + std::to_string(i + 1), cat, Out.str());
1328+
initDebugParam(
1329+
matchFunc, i, "_" + std::to_string(i + 1), cat, ast_to_string(*sort));
13401330
}
13411331
auto header = stepFunctionHeader(
13421332
axiom->getOrdinal(), module, definition, block, stuck, args, types);

0 commit comments

Comments
 (0)