Skip to content

Commit

Permalink
Emit table of return sorts (#911)
Browse files Browse the repository at this point in the history
Part of: #905

In #905, we are implementing a Python binding for the backend's function
evaluator: given a function label and list of argument `Pattern`s,
construct runtime terms for the arguments, evaluate the function with
the given label, and return the result as an AST pattern.

To safely reify the runtime term produced by the function call to an AST
pattern, we need to know its sort (so that the machinery in #907, #908
can be used correctly). In some places in the bindings, we have to
require that callers provide a sort when reifying terms back to
patterns. However, when calling a function, the label of the function
determines precisely the correct sort to use.

This PR emits a new table of global data into compiled interpreters that
maps tags to declared return sorts, along with a function that abstracts
away indexing into this table. This change is similar to (but simpler
than) an existing table of _argument sorts_ for each symbol that we
already emit.

Testing is handled by binding the new function to Python.
  • Loading branch information
Baltoli authored Dec 1, 2023
1 parent 4cbbb3a commit 44f29b2
Show file tree
Hide file tree
Showing 13 changed files with 6,268 additions and 0 deletions.
5 changes: 5 additions & 0 deletions bindings/core/src/core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ void *constructInitialConfiguration(const KOREPattern *);

namespace kllvm::bindings {

std::string return_sort_for_label(std::string const &label) {
auto tag = getTagForSymbolName(label.c_str());
return getReturnSortForTag(tag);
}

std::shared_ptr<KOREPattern> make_injection(
std::shared_ptr<KOREPattern> term, std::shared_ptr<KORESort> from,
std::shared_ptr<KORESort> to) {
Expand Down
2 changes: 2 additions & 0 deletions bindings/python/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ void bind_runtime(py::module_ &m) {
m.def("simplify_pattern", bindings::simplify);
m.def("simplify_bool_pattern", bindings::simplify_to_bool);

m.def("return_sort_for_label", bindings::return_sort_for_label);

// This class can't be used directly from Python; the mutability semantics
// that we get from the Pybind wrappers make it really easy to break things.
// We therefore have to wrap it up in some external Python code; see
Expand Down
12 changes: 12 additions & 0 deletions include/kllvm/ast/AST.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,18 @@ using sptr = std::shared_ptr<T>;

std::string decodeKore(std::string);

/*
* Helper function to avoid repeated call-site uses of ostringstream when we
* just want the string representation of a node, rather than to print it to a
* stream.
*/
template <typename T>
std::string ast_to_string(T &&node) {
auto os = std::ostringstream{};
std::forward<T>(node).print(os);
return os.str();
}

// KORESort
class KORESort : public std::enable_shared_from_this<KORESort> {
public:
Expand Down
2 changes: 2 additions & 0 deletions include/kllvm/bindings/core/core.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

namespace kllvm::bindings {

std::string return_sort_for_label(std::string const &label);

std::shared_ptr<kllvm::KOREPattern> make_injection(
std::shared_ptr<kllvm::KOREPattern> term,
std::shared_ptr<kllvm::KORESort> from, std::shared_ptr<kllvm::KORESort> to);
Expand Down
1 change: 1 addition & 0 deletions include/runtime/header.h
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ uint32_t getInjectionForSortOfTag(uint32_t tag);
bool hook_STRING_eq(SortString, SortString);

const char *getSymbolNameForTag(uint32_t tag);
const char *getReturnSortForTag(uint32_t tag);
const char *topSort(void);

typedef struct {
Expand Down
50 changes: 50 additions & 0 deletions lib/codegen/EmitConfigParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1308,6 +1308,55 @@ static void emitSortTable(KOREDefinition *definition, llvm::Module *module) {
}
}

/*
* Emit a table mapping symbol tags to the declared return sort for that symbol.
* For example:
*
* tag_of(initGeneratedTopCell) |-> sort_name_SortGeneratedTopCell{}
*
* Each value in the table is a pointer to a global variable containing the
* relevant sort name as a null-terminated string.
*
* The function `getReturnSortForTag` abstracts accesses to the data in this
* table.
*/
static void
emitReturnSortTable(KOREDefinition *definition, llvm::Module *module) {
auto &ctx = module->getContext();

auto const &syms = definition->getSymbols();

auto element_type = llvm::Type::getInt8PtrTy(ctx);
auto table_type = llvm::ArrayType::get(element_type, syms.size());

auto table = module->getOrInsertGlobal("return_sort_table", table_type);
auto values = std::vector<llvm::Constant *>{};

for (auto [tag, symbol] : syms) {
auto sort = symbol->getSort();
auto sort_str = ast_to_string(*sort);

auto char_type = llvm::Type::getInt8Ty(ctx);
auto str_type = llvm::ArrayType::get(char_type, sort_str.size() + 1);

auto sort_name
= module->getOrInsertGlobal("sort_name_" + sort_str, str_type);

auto i64_type = llvm::Type::getInt64Ty(ctx);
auto zero = llvm::ConstantInt::get(i64_type, 0);

auto pointer = llvm::ConstantExpr::getInBoundsGetElementPtr(
str_type, sort_name, std::vector<llvm::Constant *>{zero});

values.push_back(pointer);
}

auto global = llvm::dyn_cast<llvm::GlobalVariable>(table);
if (!global->hasInitializer()) {
global->setInitializer(llvm::ConstantArray::get(table_type, values));
}
}

void emitConfigParserFunctions(
KOREDefinition *definition, llvm::Module *module) {
emitGetTagForSymbolName(definition, module);
Expand All @@ -1329,6 +1378,7 @@ void emitConfigParserFunctions(
emitInjTags(definition, module);

emitSortTable(definition, module);
emitReturnSortTable(definition, module);
}

} // namespace kllvm
6 changes: 6 additions & 0 deletions runtime/util/util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@

extern "C" {

extern char *return_sort_table;

const char *getReturnSortForTag(uint32_t tag) {
return (&return_sort_table)[tag];
}

block *dot_k() {
return leaf_block(getTagForSymbolName("dotk{}"));
}
Expand Down
6,145 changes: 6,145 additions & 0 deletions test/python/Inputs/sorts.kore

Large diffs are not rendered by default.

12 changes: 12 additions & 0 deletions test/python/k-files/sorts.k
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
module SORTS
imports DOMAINS

syntax Int ::= func() [function, label(func), symbol]

syntax Foo ::= foo() [label(foo), symbol]
syntax Bar ::= Foo
| bar() [label(bar), symbol]

rule func() => 0
rule foo() => bar()
endmodule
26 changes: 26 additions & 0 deletions test/python/test_return_sorts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# RUN: mkdir -p %t
# RUN: export IN=$(realpath Inputs/sorts.kore)
# RUN: cd %t && %kompile "$IN" python --python %py-interpreter --python-output-dir .
# RUN: KLLVM_DEFINITION=%t %python -u %s

from test_bindings import kllvm

import unittest

class TestReturnSorts(unittest.TestCase):

def _check_sort(self, label, sort):
self.assertEqual(kllvm.runtime.return_sort_for_label(label), sort)

def test_function(self):
self._check_sort('Lblfunc{}', 'SortInt{}')

def test_constructor(self):
self._check_sort('Lblfoo{}', 'SortFoo{}')

def test_subsort(self):
self._check_sort('Lblbar{}', 'SortBar{}')


if __name__ == "__main__":
unittest.main()
2 changes: 2 additions & 0 deletions unittests/runtime-ffi/ffi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
#define KCHAR char
#define TYPETAG(type) "Lbl'Hash'ffi'Unds'" #type "{}"

char *return_sort_table = nullptr;

void *constructCompositePattern(uint32_t tag, std::vector<void *> &arguments) {
return nullptr;
}
Expand Down
2 changes: 2 additions & 0 deletions unittests/runtime-io/io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

#define KCHAR char

char *return_sort_table = nullptr;

void *constructCompositePattern(uint32_t tag, std::vector<void *> &arguments) {
return nullptr;
}
Expand Down
3 changes: 3 additions & 0 deletions unittests/runtime-strings/bytestest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

#define KCHAR char
extern "C" {

char *return_sort_table = nullptr;

uint32_t getTagForSymbolName(const char *s) {
return 0;
}
Expand Down

0 comments on commit 44f29b2

Please sign in to comment.