Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose function evaluation to bindings #905

Merged
merged 12 commits into from
Dec 18, 2023
15 changes: 15 additions & 0 deletions bindings/core/src/core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,21 @@ simplify(std::shared_ptr<KOREPattern> pattern, std::shared_ptr<KORESort> sort) {
return term_to_pattern(simplify_to_term(pattern, sort));
}

std::shared_ptr<KOREPattern>
evaluate_function(std::shared_ptr<KORECompositePattern> term) {
auto term_args = std::vector<void *>{};
for (auto const &arg : term->getArguments()) {
term_args.push_back(static_cast<void *>(construct_term(arg)));
}

auto label = ast_to_string(*term->getConstructor());
auto tag = getTagForSymbolName(label.c_str());
auto return_sort = getReturnSortForTag(tag);
auto result = evaluateFunctionSymbol(tag, term_args.data());

return sortedTermToKorePattern(static_cast<block *>(result), return_sort);
}

bool is_sort_kitem(std::shared_ptr<KORESort> const &sort) {
if (auto composite = std::dynamic_pointer_cast<KORECompositeSort>(sort)) {
return composite->getName() == "SortKItem";
Expand Down
3 changes: 3 additions & 0 deletions bindings/python/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include <kllvm/bindings/core/core.h>

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

// This header needs to be included last because it pollutes a number of macro
// definitions into the global namespace.
Expand Down Expand Up @@ -55,6 +56,8 @@ void bind_runtime(py::module_ &m) {

m.def("return_sort_for_label", bindings::return_sort_for_label);

m.def("evaluate_function", bindings::evaluate_function);

// This class can't be used directly from Python; the mutability semantics
// that we get from the Pybind wrappers make it really easy to break things.
// We therefore have to wrap it up in some external Python code; see
Expand Down
3 changes: 3 additions & 0 deletions include/kllvm/bindings/core/core.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ std::shared_ptr<kllvm::KOREPattern> simplify(
bool is_sort_kitem(std::shared_ptr<kllvm::KORESort> const &sort);
bool is_sort_k(std::shared_ptr<kllvm::KORESort> const &sort);

std::shared_ptr<KOREPattern>
evaluate_function(std::shared_ptr<KORECompositePattern> term);
Baltoli marked this conversation as resolved.
Show resolved Hide resolved

} // namespace kllvm::bindings

#endif
4 changes: 3 additions & 1 deletion include/runtime/header.h
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,8 @@ void printConfigurationInternal(
// you can use the C bindings, which wrap the return value of this method in
// a POD struct.
std::shared_ptr<kllvm::KOREPattern> termToKorePattern(block *);
std::shared_ptr<kllvm::KOREPattern>
sortedTermToKorePattern(block *, const char *);

// This function injects its argument into KItem before printing, using the sort
// argument as the source sort. Doing so allows the term to be pretty-printed
Expand Down Expand Up @@ -409,7 +411,7 @@ extern const uint32_t first_inj_tag, last_inj_tag;
bool is_injection(block *);
block *strip_injection(block *);
block *constructKItemInj(void *subject, const char *sort, bool raw_value);
block *constructRawTerm(void *subject, const char *sort);
block *constructRawTerm(void *subject, const char *sort, bool raw_value);
}

std::string floatToString(const floating *);
Expand Down
14 changes: 11 additions & 3 deletions runtime/util/ConfigurationSerializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ void serializeTermToFile(

void serializeRawTermToFile(
const char *filename, void *subject, const char *sort) {
block *term = constructRawTerm(subject, sort);
block *term = constructRawTerm(subject, sort, true);

char *data;
size_t size;
Expand All @@ -471,13 +471,21 @@ void serializeRawTermToFile(
fclose(file);
}

std::shared_ptr<kllvm::KOREPattern> termToKorePattern(block *subject) {
std::shared_ptr<kllvm::KOREPattern>
sortedTermToKorePattern(block *subject, const char *sort) {
auto is_kitem = (std::string(sort) == "SortKItem{}");
block *term = is_kitem ? subject : constructRawTerm(subject, sort, false);

char *data_out;
size_t size_out;

serializeConfiguration(subject, "SortKItem{}", &data_out, &size_out, true);
serializeConfiguration(term, "SortKItem{}", &data_out, &size_out, true);
auto result = deserialize_pattern(data_out, data_out + size_out);

free(data_out);
return result;
}

std::shared_ptr<kllvm::KOREPattern> termToKorePattern(block *subject) {
return sortedTermToKorePattern(subject, "SortKItem{}");
}
4 changes: 2 additions & 2 deletions runtime/util/util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ block *constructKItemInj(void *subject, const char *sort, bool raw_value) {
return static_cast<block *>(constructCompositePattern(tag, args));
}

block *constructRawTerm(void *subject, const char *sort) {
block *constructRawTerm(void *subject, const char *sort, bool raw_value) {
auto tag = getTagForSymbolName("rawTerm{}");
auto args = std::vector{
static_cast<void *>(constructKItemInj(subject, sort, true))};
static_cast<void *>(constructKItemInj(subject, sort, raw_value))};
return static_cast<block *>(constructCompositePattern(tag, args));
}

Expand Down
2,213 changes: 2,213 additions & 0 deletions test/python/Inputs/evaluate.kore

Large diffs are not rendered by default.

16 changes: 16 additions & 0 deletions test/python/k-files/evaluate.k
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
module EVALUATE
imports INT
imports BOOL

syntax Foo ::= foo(Int) [label(foo), symbol]
| bar(Int) [label(bar), symbol]

syntax Foo ::= f(Foo, Foo) [function, label(f), symbol]
rule f(foo(A), bar(B)) => foo(A +Int B)

syntax Int ::= baz() [function, label(baz), symbol]
syntax Bool ::= qux(Int) [function, label(qux), symbol]

rule baz() => 78
rule qux(I) => I ==Int 34
endmodule
45 changes: 45 additions & 0 deletions test/python/test_evaluate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# RUN: mkdir -p %t
# RUN: export IN=$(realpath Inputs/evaluate.kore)
# RUN: cd %t && %kompile "$IN" python --python %py-interpreter --python-output-dir .
# RUN: KLLVM_DEFINITION=%t %python -u %s

from test_bindings import kllvm

import unittest


class TestEvaluate(unittest.TestCase):

def test_ctor(self):
call = kllvm.parser.Parser.from_string('Lblf{}(Lblfoo{}(\\dv{SortInt{}}("23")),Lblbar{}(\\dv{SortInt{}}("56")))').pattern()

result = kllvm.runtime.evaluate_function(call)
self.assertEqual(str(result), 'Lblfoo{}(\dv{SortInt{}}("79"))')

def test_int_function(self):
call = kllvm.ast.CompositePattern('Lblbaz')

result = kllvm.runtime.evaluate_function(call)
self.assertEqual(str(result), '\dv{SortInt{}}("78")')

def test_true_function(self):
arg_t = kllvm.parser.Parser.from_string('\\dv{SortInt{}}("34")').pattern()

call = kllvm.ast.CompositePattern('Lblqux')
call.add_argument(arg_t)

result_t = kllvm.runtime.evaluate_function(call)
self.assertEqual(str(result_t), '\dv{SortBool{}}("true")')

def test_false_function(self):
arg_f = kllvm.parser.Parser.from_string('\\dv{SortInt{}}("98")').pattern()

call = kllvm.ast.CompositePattern('Lblqux')
call.add_argument(arg_f)

result_f = kllvm.runtime.evaluate_function(call)
self.assertEqual(str(result_f), '\dv{SortBool{}}("false")')


if __name__ == "__main__":
unittest.main()
Loading