Skip to content

Commit

Permalink
pr: Add option to use interactive server in parser conversion pass.
Browse files Browse the repository at this point in the history
  • Loading branch information
frabert committed Jan 31, 2025
1 parent 34718b2 commit e6298fb
Show file tree
Hide file tree
Showing 4 changed files with 244 additions and 19 deletions.
2 changes: 2 additions & 0 deletions include/vast/Conversion/Parser/Passes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

#include "vast/Util/Warnings.hpp"

#include "vast/server/server.hpp"

VAST_RELAX_WARNINGS
#include <mlir/Pass/Pass.h>
#include <mlir/Pass/PassManager.h>
Expand Down
3 changes: 3 additions & 0 deletions include/vast/Conversion/Parser/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ def HLToParser : Pass<"vast-hl-to-parser", "core::ModuleOp"> {
let options = [
Option< "config", "config", "std::string", "",
"Configuration file for parser transformation."
>,
Option< "socket", "socket", "std::string", "",
"Unix socket path to use for server"
>
];

Expand Down
44 changes: 43 additions & 1 deletion include/vast/server/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,34 @@

#include <concepts>
#include <cstdint>
#include <optional>
#include <string>
#include <variant>

#include <nlohmann/json.hpp>

namespace nlohmann {
template< typename T >
struct adl_serializer< std::optional< T > >
{
static void to_json(json &j, const std::optional< T > &opt) {
if (!opt.has_value()) {
j = nullptr;
} else {
j = *opt;
}
}

static void from_json(const json &j, std::optional< T > &opt) {
if (j.is_null()) {
opt = std::nullopt;
} else {
opt = j.template get< T >();
}
}
};
} // namespace nlohmann

namespace vast::server {
template< typename T >
concept json_convertible = requires(T obj, nlohmann::json &json) {
Expand Down Expand Up @@ -88,14 +111,33 @@ namespace vast::server {
template< request_like request >
using result_type = std::variant< typename request::response_type, error< request > >;

struct position
{
unsigned int line;
unsigned int character;

NLOHMANN_DEFINE_TYPE_INTRUSIVE(position, line, character)
};

struct range
{
position start;
position end;

NLOHMANN_DEFINE_TYPE_INTRUSIVE(range, start, end)
};

struct input_request
{
static constexpr const char *method = "input";
static constexpr bool is_notification = false;

nlohmann::json type;
std::string text;
NLOHMANN_DEFINE_TYPE_INTRUSIVE(input_request, type, text)
std::optional< std::string > filePath;
std::optional< range > range;

NLOHMANN_DEFINE_TYPE_INTRUSIVE(input_request, type, text, filePath, range)

struct response_type
{
Expand Down
214 changes: 196 additions & 18 deletions lib/vast/Conversion/Parser/ToParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ VAST_UNRELAX_WARNINGS

#include "vast/Conversion/Parser/Config.hpp"

#include "vast/server/server.hpp"
#include "vast/server/types.hpp"

#include <ranges>

namespace vast::conv {
Expand Down Expand Up @@ -75,6 +78,154 @@ namespace vast::conv {

using function_models = llvm::StringMap< function_model >;

struct location
{
std::string filePath;
server::range range;
};

location get_location(file_loc_t loc) {
return {
.filePath = loc.getFilename().str(),
.range = {
.start = { loc.getLine(), loc.getColumn(), },
.end = { loc.getLine(), loc.getColumn(), },
},
};
}

location get_location(name_loc_t loc) {
return get_location(mlir::cast< file_loc_t >(loc.getChildLoc()));
}

std::optional< location > get_location(loc_t loc) {
if (auto file_loc = mlir::dyn_cast< file_loc_t >(loc)) {
return get_location(file_loc);
} else if (auto name_loc = mlir::dyn_cast< name_loc_t >(loc)) {
return get_location(name_loc);
}

return std::nullopt;
}

pr::data_type parse_type_name(const std::string &name) {
if (name == "data") {
return pr::data_type::data;
} else if (name == "nodata") {
return pr::data_type::nodata;
} else {
return pr::data_type::maybedata;
}
}

function_category
ask_user_for_category(vast::server::server_base &server, core::function_op_interface op) {
auto loc = op.getLoc();
auto sym = mlir::dyn_cast< core::SymbolOpInterface >(op.getOperation());
VAST_ASSERT(sym);
auto name = sym.getSymbolName().str();

vast::server::input_request req{
.type = {"nonparser", "sink", "source", "parser",},
.text = "Please choose category for function `" + name + '`',
.filePath = std::nullopt,
.range = std::nullopt,
};

if (auto req_loc = get_location(loc)) {
req.filePath = req_loc->filePath;
req.range = req_loc->range;
}

auto response = server.send_request(req);
if (auto result = std::get_if< vast::server::input_request::response_type >(&response))
{
if (result->value == "nonparser") {
return function_category::nonparser;
} else if (result->value == "sink") {
return function_category::sink;
} else if (result->value == "source") {
return function_category::source;
} else if (result->value == "parser") {
return function_category::parser;
}
}
return function_category::nonparser;
}

pr::data_type ask_user_for_return_type(
vast::server::server_base &server, core::function_op_interface op
) {
auto loc = op.getLoc();
auto sym = mlir::dyn_cast< core::SymbolOpInterface >(op.getOperation());
VAST_ASSERT(sym);
auto name = sym.getSymbolName().str();

vast::server::input_request req{
.type = { "maybedata", "nodata", "data" },
.text = "Please choose return type for function `" + name + '`',
.filePath = std::nullopt,
.range = std::nullopt,
};

if (auto req_loc = get_location(loc)) {
req.filePath = req_loc->filePath;
req.range = req_loc->range;
}

auto response = server.send_request(req);
if (auto result = std::get_if< vast::server::input_request::response_type >(&response))
{
return parse_type_name(result->value);
}
return pr::data_type::maybedata;
}

pr::data_type ask_user_for_argument_type(
vast::server::server_base &server, core::function_op_interface op, unsigned int idx
) {
auto num_body_args = op.getFunctionBody().getNumArguments();
auto sym = mlir::dyn_cast< core::SymbolOpInterface >(op.getOperation());
VAST_ASSERT(sym);
auto name = sym.getSymbolName().str();

vast::server::input_request req{
.type = { "maybedata", "nodata", "data" },
.text = "Please choose a type for argument " + std::to_string(idx)
+ " of function `" + name + '`',
.filePath = std::nullopt,
.range = std::nullopt,
};

if (idx < num_body_args) {
auto arg = op.getArgument(idx);
auto loc = arg.getLoc();
if (auto req_loc = get_location(loc)) {
req.filePath = req_loc->filePath;
req.range = req_loc->range;
}
}

auto response = server.send_request(req);
if (auto result = std::get_if< vast::server::input_request::response_type >(&response))
{
return parse_type_name(result->value);
}
return pr::data_type::maybedata;
}

function_model ask_user_for_function_model(
vast::server::server_base &server, core::function_op_interface op
) {
function_model model;
model.return_type = ask_user_for_return_type(server, op);
for (unsigned int i = 0; i < op.getNumArguments(); ++i) {
model.arguments.push_back(ask_user_for_argument_type(server, op, i));
}
model.category = ask_user_for_category(server, op);
return model;
}

} // namespace vast::conv

LLVM_YAML_IS_SEQUENCE_VECTOR(vast::pr::data_type);
Expand Down Expand Up @@ -130,25 +281,28 @@ namespace vast::conv {
using base = base_conversion_config;

parser_conversion_config(
rewrite_pattern_set patterns, conversion_target target,
const function_models &models
rewrite_pattern_set patterns, conversion_target target, function_models &models,
vast::server::server_base *server
)
: base(std::move(patterns), std::move(target)), models(models)
{}
: base(std::move(patterns), std::move(target)), models(models), server(server) {}

template< typename pattern >
void add_pattern() {
auto ctx = patterns.getContext();
if constexpr (std::is_constructible_v< pattern, mcontext_t * >) {
patterns.template add< pattern >(ctx);
} else if constexpr (std::is_constructible_v< pattern, mcontext_t *, const function_models & >) {
patterns.template add< pattern >(ctx, models);
} else if constexpr (std::is_constructible_v<
pattern, mcontext_t *, function_models &,
vast::server::server_base * >)
{
patterns.template add< pattern >(ctx, models, server);
} else {
static_assert(false, "pattern does not have a valid constructor");
}
}

const function_models &models;
function_models &models;
vast::server::server_base *server;
};

struct function_type_converter
Expand Down Expand Up @@ -277,26 +431,36 @@ namespace vast::conv {
{
using base = mlir::OpConversionPattern< op_t >;

parser_conversion_pattern_base(mcontext_t *mctx, const function_models &models)
: base(mctx), models(models)
{}
parser_conversion_pattern_base(
mcontext_t *mctx, function_models &models, vast::server::server_base *server
)
: base(mctx), models(models), server(server) {}

static std::optional< function_model >
get_model(const function_models &models, core::function_op_interface op) {
static std::optional< function_model > get_model(
function_models &models, core::function_op_interface op,
vast::server::server_base *server
) {
auto sym = mlir::dyn_cast< core::SymbolOpInterface >(op.getOperation());
VAST_ASSERT(sym);
if (auto kv = models.find(sym.getSymbolName()); kv != models.end()) {
return kv->second;
}

if (server) {
auto model = ask_user_for_function_model(*server, op);
models[sym.getSymbolName()] = model;
return model;
}

return std::nullopt;
}

std::optional< function_model > get_model(core::function_op_interface op) const {
return get_model(models, op);
return get_model(models, op, server);
}

const function_models &models;
function_models &models;
vast::server::server_base *server;
};

//
Expand Down Expand Up @@ -543,10 +707,13 @@ namespace vast::conv {
return mlir::failure();
}

static void legalize(parser_conversion_config &cfg) {
static void
legalize(parser_conversion_config &cfg, vast::server::server_base *server) {
cfg.target.addLegalOp< mlir::UnrealizedConversionCastOp >();
cfg.target.addDynamicallyLegalOp< op_t >([models = cfg.models](op_t op) {
return function_type_converter(*op.getContext(), get_model(models, op))
cfg.target.addDynamicallyLegalOp< op_t >([&cfg, server](op_t op) {
return function_type_converter(
*op.getContext(), get_model(cfg.models, op, server)
)
.isLegal(op.getFunctionType());
});
}
Expand Down Expand Up @@ -724,6 +891,9 @@ namespace vast::conv {
{
using base = ConversionPassMixin< HLToParserPass, HLToParserBase >;

struct server_handler
{};

static conversion_target create_conversion_target(mcontext_t &mctx) {
return conversion_target(mctx);
}
Expand All @@ -738,6 +908,12 @@ namespace vast::conv {
if (!config.empty()) {
load_and_parse(config);
}

if (!socket.empty()) {
server = std::make_shared< vast::server::server< server_handler > >(
vast::server::sock_adapter::create_unix_socket(socket)
);
}
}

void load_and_parse(string_ref config) {
Expand All @@ -764,10 +940,12 @@ namespace vast::conv {

parser_conversion_config make_config() {
auto &ctx = getContext();
return { rewrite_pattern_set(&ctx), create_conversion_target(ctx), models };
return { rewrite_pattern_set(&ctx), create_conversion_target(ctx), models,
server.get() };
}

function_models models;
std::shared_ptr< vast::server::server< server_handler > > server;
};

} // namespace vast::conv
Expand Down

0 comments on commit e6298fb

Please sign in to comment.