diff --git a/include/vast/Conversion/Parser/Passes.hpp b/include/vast/Conversion/Parser/Passes.hpp index e0f3c2f5eb..4da64db19d 100644 --- a/include/vast/Conversion/Parser/Passes.hpp +++ b/include/vast/Conversion/Parser/Passes.hpp @@ -4,6 +4,8 @@ #include "vast/Util/Warnings.hpp" +#include "vast/server/server.hpp" + VAST_RELAX_WARNINGS #include #include diff --git a/include/vast/Conversion/Parser/Passes.td b/include/vast/Conversion/Parser/Passes.td index 585fb54b82..2cdd1abb57 100644 --- a/include/vast/Conversion/Parser/Passes.td +++ b/include/vast/Conversion/Parser/Passes.td @@ -10,6 +10,9 @@ def HLToParser : Pass<"vast-hl-to-parser", "core::ModuleOp"> { let options = [ Option< "config", "config", "std::string", "", "Configuration file for parser transformation." + >, + Option< "socket", "socket", "std::string", "", + "Unix socket path to use for server" > ]; diff --git a/include/vast/server/types.hpp b/include/vast/server/types.hpp index 337e9fab11..2c6370fab1 100644 --- a/include/vast/server/types.hpp +++ b/include/vast/server/types.hpp @@ -4,11 +4,34 @@ #include #include +#include #include #include #include +namespace nlohmann { + template< typename T > + struct adl_serializer< std::optional< T > > + { + static void to_json(json &j, const std::optional< T > &opt) { + if (!opt.has_value()) { + j = nullptr; + } else { + j = *opt; + } + } + + static void from_json(const json &j, std::optional< T > &opt) { + if (j.is_null()) { + opt = std::nullopt; + } else { + opt = j.template get< T >(); + } + } + }; +} // namespace nlohmann + namespace vast::server { template< typename T > concept json_convertible = requires(T obj, nlohmann::json &json) { @@ -88,6 +111,22 @@ namespace vast::server { template< request_like request > using result_type = std::variant< typename request::response_type, error< request > >; + struct position + { + unsigned int line; + unsigned int character; + + NLOHMANN_DEFINE_TYPE_INTRUSIVE(position, line, character) + }; + + struct range + { + position start; + position end; + + NLOHMANN_DEFINE_TYPE_INTRUSIVE(range, start, end) + }; + struct input_request { static constexpr const char *method = "input"; @@ -95,7 +134,10 @@ namespace vast::server { nlohmann::json type; std::string text; - NLOHMANN_DEFINE_TYPE_INTRUSIVE(input_request, type, text) + std::optional< std::string > filePath; + std::optional< range > range; + + NLOHMANN_DEFINE_TYPE_INTRUSIVE(input_request, type, text, filePath, range) struct response_type { diff --git a/lib/vast/Conversion/Parser/ToParser.cpp b/lib/vast/Conversion/Parser/ToParser.cpp index dcb7c72f3c..1cf15de2ae 100644 --- a/lib/vast/Conversion/Parser/ToParser.cpp +++ b/lib/vast/Conversion/Parser/ToParser.cpp @@ -29,6 +29,9 @@ VAST_UNRELAX_WARNINGS #include "vast/Conversion/Parser/Config.hpp" +#include "vast/server/server.hpp" +#include "vast/server/types.hpp" + #include namespace vast::conv { @@ -75,6 +78,141 @@ namespace vast::conv { using function_models = llvm::StringMap< function_model >; + struct location + { + std::string filePath; + server::range range; + }; + + location get_location(file_loc_t loc) { + return { + .filePath = loc.getFilename().str(), + .range = { + .start = { loc.getLine(), loc.getColumn(), }, + .end = { loc.getLine(), loc.getColumn(), }, + }, + }; + } + + location get_location(name_loc_t loc) { + return get_location(mlir::cast< file_loc_t >(loc.getChildLoc())); + } + + std::optional< location > get_location(loc_t loc) { + if (auto file_loc = mlir::dyn_cast< file_loc_t >(loc)) { + return get_location(file_loc); + } else if (auto name_loc = mlir::dyn_cast< name_loc_t >(loc)) { + return get_location(name_loc); + } + + return std::nullopt; + } + + pr::data_type parse_type_name(const std::string &name) { + if (name == "data") { + return pr::data_type::data; + } else if (name == "nodata") { + return pr::data_type::nodata; + } else { + return pr::data_type::maybedata; + } + } + + function_category ask_user_for_category(vast::server::server_base &server, hl::FuncOp op) { + auto loc = op.getLoc(); + + vast::server::input_request req{ + .type = {"nonparser", "sink", "source", "parser",}, + .text = "Please choose category for function `" + op.getSymName().str() + '`', + .filePath = std::nullopt, + .range = std::nullopt, + }; + + if (auto req_loc = get_location(loc)) { + req.filePath = req_loc->filePath; + req.range = req_loc->range; + } + + auto response = server.send_request(req); + if (auto result = std::get_if< vast::server::input_request::response_type >(&response)) + { + if (result->value == "nonparser") { + return function_category::nonparser; + } else if (result->value == "sink") { + return function_category::sink; + } else if (result->value == "source") { + return function_category::source; + } else if (result->value == "parser") { + return function_category::parser; + } + } + return function_category::nonparser; + } + + pr::data_type ask_user_for_return_type(vast::server::server_base &server, hl::FuncOp op) { + auto loc = op.getLoc(); + + vast::server::input_request req{ + .type = { "maybedata", "nodata", "data" }, + .text = "Please choose return type for function `" + op.getSymName().str() + '`', + .filePath = std::nullopt, + .range = std::nullopt, + }; + + if (auto req_loc = get_location(loc)) { + req.filePath = req_loc->filePath; + req.range = req_loc->range; + } + + auto response = server.send_request(req); + if (auto result = std::get_if< vast::server::input_request::response_type >(&response)) + { + return parse_type_name(result->value); + } + return pr::data_type::maybedata; + } + + pr::data_type ask_user_for_argument_type( + vast::server::server_base &server, hl::FuncOp op, unsigned int idx + ) { + auto num_body_args = op.getFunctionBody().getNumArguments(); + + vast::server::input_request req{ + .type = { "maybedata", "nodata", "data" }, + .text = "Please choose a type for argument " + std::to_string(idx) + + " of function `" + op.getSymName().str() + '`', + .filePath = std::nullopt, + .range = std::nullopt, + }; + + if (idx < num_body_args) { + auto arg = op.getArgument(idx); + auto loc = arg.getLoc(); + if (auto req_loc = get_location(loc)) { + req.filePath = req_loc->filePath; + req.range = req_loc->range; + } + } + + auto response = server.send_request(req); + if (auto result = std::get_if< vast::server::input_request::response_type >(&response)) + { + return parse_type_name(result->value); + } + return pr::data_type::maybedata; + } + + function_model + ask_user_for_function_model(vast::server::server_base &server, hl::FuncOp op) { + function_model model; + model.return_type = ask_user_for_return_type(server, op); + for (unsigned int i = 0; i < op.getNumArguments(); ++i) { + model.arguments.push_back(ask_user_for_argument_type(server, op, i)); + } + model.category = ask_user_for_category(server, op); + return model; + } + } // namespace vast::conv LLVM_YAML_IS_SEQUENCE_VECTOR(vast::pr::data_type); @@ -130,25 +268,28 @@ namespace vast::conv { using base = base_conversion_config; parser_conversion_config( - rewrite_pattern_set patterns, conversion_target target, - const function_models &models + rewrite_pattern_set patterns, conversion_target target, function_models &models, + vast::server::server_base *server ) - : base(std::move(patterns), std::move(target)), models(models) - {} + : base(std::move(patterns), std::move(target)), models(models), server(server) {} template< typename pattern > void add_pattern() { auto ctx = patterns.getContext(); if constexpr (std::is_constructible_v< pattern, mcontext_t * >) { patterns.template add< pattern >(ctx); - } else if constexpr (std::is_constructible_v< pattern, mcontext_t *, const function_models & >) { - patterns.template add< pattern >(ctx, models); + } else if constexpr (std::is_constructible_v< + pattern, mcontext_t *, function_models &, + vast::server::server_base * >) + { + patterns.template add< pattern >(ctx, models, server); } else { static_assert(false, "pattern does not have a valid constructor"); } } - const function_models ⊧ + function_models ⊧ + vast::server::server_base *server; }; struct function_type_converter @@ -277,24 +418,33 @@ namespace vast::conv { { using base = mlir::OpConversionPattern< op_t >; - parser_conversion_pattern_base(mcontext_t *mctx, const function_models &models) - : base(mctx), models(models) - {} + parser_conversion_pattern_base( + mcontext_t *mctx, function_models &models, vast::server::server_base *server + ) + : base(mctx), models(models), server(server) {} - static std::optional< function_model > - get_model(const function_models &models, hl::FuncOp func) { + static std::optional< function_model > get_model( + function_models &models, hl::FuncOp func, vast::server::server_base *server + ) { if (auto kv = models.find(func.getSymName()); kv != models.end()) { return kv->second; } + if (server) { + auto model = ask_user_for_function_model(*server, func); + models[func.getSymName()] = model; + return model; + } + return std::nullopt; } std::optional< function_model > get_model(hl::FuncOp func) const { - return get_model(models, func); + return get_model(models, func, server); } - const function_models ⊧ + function_models ⊧ + vast::server::server_base *server; }; // @@ -541,10 +691,13 @@ namespace vast::conv { return mlir::failure(); } - static void legalize(parser_conversion_config &cfg) { + static void + legalize(parser_conversion_config &cfg, vast::server::server_base *server) { cfg.target.addLegalOp< mlir::UnrealizedConversionCastOp >(); - cfg.target.addDynamicallyLegalOp< op_t >([models = cfg.models](op_t op) { - return function_type_converter(*op.getContext(), get_model(models, op)) + cfg.target.addDynamicallyLegalOp< op_t >([&cfg, server](op_t op) { + return function_type_converter( + *op.getContext(), get_model(cfg.models, op, server) + ) .isLegal(op.getFunctionType()); }); } @@ -722,6 +875,9 @@ namespace vast::conv { { using base = ConversionPassMixin< HLToParserPass, HLToParserBase >; + struct server_handler + {}; + static conversion_target create_conversion_target(mcontext_t &mctx) { return conversion_target(mctx); } @@ -736,6 +892,12 @@ namespace vast::conv { if (!config.empty()) { load_and_parse(config); } + + if (!socket.empty()) { + server = std::make_shared< vast::server::server< server_handler > >( + vast::server::sock_adapter::create_unix_socket(socket) + ); + } } void load_and_parse(string_ref config) { @@ -762,10 +924,12 @@ namespace vast::conv { parser_conversion_config make_config() { auto &ctx = getContext(); - return { rewrite_pattern_set(&ctx), create_conversion_target(ctx), models }; + return { rewrite_pattern_set(&ctx), create_conversion_target(ctx), models, + server.get() }; } function_models models; + std::shared_ptr< vast::server::server< server_handler > > server; }; } // namespace vast::conv