From ae3abc7f0e8622f1d66e5efeb1f91fc4529883ba Mon Sep 17 00:00:00 2001 From: Sophie Zhang Date: Sun, 31 Mar 2024 14:05:00 -0400 Subject: [PATCH 01/19] Define and pass in Python callback to native FlightSQLServer implementation --- cpp/pybind/brad_server.cc | 2 ++ cpp/server/brad_server_simple.cc | 14 +++++++++++++- cpp/server/brad_server_simple.h | 8 +++++++- src/brad/front_end/flight_sql_server.py | 4 ++-- src/brad/front_end/front_end.py | 22 ++++++++++++++++++---- src/brad/row_list.py | 3 +++ 6 files changed, 45 insertions(+), 8 deletions(-) diff --git a/cpp/pybind/brad_server.cc b/cpp/pybind/brad_server.cc index be006996..3c8c2ac5 100644 --- a/cpp/pybind/brad_server.cc +++ b/cpp/pybind/brad_server.cc @@ -1,4 +1,6 @@ #include +#include +#include #include diff --git a/cpp/server/brad_server_simple.cc b/cpp/server/brad_server_simple.cc index 19321e76..f77e3bca 100644 --- a/cpp/server/brad_server_simple.cc +++ b/cpp/server/brad_server_simple.cc @@ -5,6 +5,9 @@ #include #include #include +#include +#include +#include #include #include "brad_sql_info.h" @@ -59,9 +62,15 @@ std::shared_ptr return result; } -void BradFlightSqlServer::InitWrapper(const std::string &host, int port) { +void BradFlightSqlServer::InitWrapper( + const std::string &host, + int port, + std::function>(std::string)> handle_query) { auto location = arrow::flight::Location::ForGrpcTcp(host, port).ValueOrDie(); arrow::flight::FlightServerOptions options(location); + + _handle_query = handle_query; + this->Init(options); } @@ -79,6 +88,9 @@ arrow::Result> const StatementQuery &command, const FlightDescriptor &descriptor) { const std::string &query = command.query; + + _handle_query(query); + ARROW_ASSIGN_OR_RAISE(auto statement, BradStatement::Create(query)); ARROW_ASSIGN_OR_RAISE(auto schema, statement->GetSchema()); ARROW_ASSIGN_OR_RAISE(auto ticket, diff --git a/cpp/server/brad_server_simple.h b/cpp/server/brad_server_simple.h index f6db4cbf..4187a91e 100644 --- a/cpp/server/brad_server_simple.h +++ b/cpp/server/brad_server_simple.h @@ -3,6 +3,8 @@ #include #include #include +#include +#include #include #include @@ -17,7 +19,9 @@ class BradFlightSqlServer : public arrow::flight::sql::FlightSqlServerBase { static std::shared_ptr Create(); - void InitWrapper(const std::string &host, int port); + void InitWrapper(const std::string &host, + int port, + std::function>(std::string)>); void ServeWrapper(); @@ -33,6 +37,8 @@ class BradFlightSqlServer : public arrow::flight::sql::FlightSqlServerBase { DoGetStatement( const arrow::flight::ServerCallContext &context, const arrow::flight::sql::StatementQueryTicket &command) override; + + std::function>(std::string)> _handle_query; }; } // namespace brad diff --git a/src/brad/front_end/flight_sql_server.py b/src/brad/front_end/flight_sql_server.py index eb736e20..279cd79d 100644 --- a/src/brad/front_end/flight_sql_server.py +++ b/src/brad/front_end/flight_sql_server.py @@ -8,9 +8,9 @@ class BradFlightSqlServer: - def __init__(self, host: str, port: int) -> None: + def __init__(self, host: str, port: int, callback) -> None: self._flight_sql_server = brad_server.BradFlightSqlServer() - self._flight_sql_server.init(host, port) + self._flight_sql_server.init(host, port, callback) self._thread = threading.Thread(name="BradFlightSqlServer", target=self._serve) def start(self) -> None: diff --git a/src/brad/front_end/front_end.py b/src/brad/front_end/front_end.py index d00b47e0..ac3a4711 100644 --- a/src/brad/front_end/front_end.py +++ b/src/brad/front_end/front_end.py @@ -45,7 +45,7 @@ from brad.routing.policy import RoutingPolicy from brad.routing.router import Router from brad.routing.tree_based.forest_policy import ForestPolicy -from brad.row_list import RowList +from brad.row_list import RowList, FixedRowList from brad.utils import log_verbose, create_custom_logger from brad.utils.counter import Counter from brad.utils.json_decimal_encoder import DecimalEncoder @@ -59,7 +59,6 @@ LINESEP = "\n".encode() - class BradFrontEnd(BradInterface): @staticmethod def native_server_is_supported() -> bool: @@ -90,7 +89,9 @@ def __init__( from brad.front_end.flight_sql_server import BradFlightSqlServer self._flight_sql_server: Optional[BradFlightSqlServer] = ( - BradFlightSqlServer(host="0.0.0.0", port=31337) + BradFlightSqlServer(host="0.0.0.0", + port=31337, + callback=self._handle_query_from_flight_sql) ) else: self._flight_sql_server = None @@ -190,11 +191,21 @@ def __init__( self._is_stub_mode = self._config.stub_mode_path is not None + def _handle_query_from_flight_sql(self, query: str) -> FixedRowList: + future = asyncio.run_coroutine_threadsafe( + self._run_query_impl(self._flight_sql_server_session_id, query, {}), + self._main_thread_loop + ) + row_result = future.result() + + return row_result + async def serve_forever(self): await self._run_setup() # Start FlightSQL server if self._flight_sql_server is not None: + self._flight_sql_server_session_id = await self.start_session() self._flight_sql_server.start() try: @@ -219,6 +230,8 @@ async def serve_forever(self): logger.debug("BRAD front end _run_teardown() complete.") async def _run_setup(self) -> None: + self._main_thread_loop = asyncio.get_running_loop() + # The directory will have been populated by the daemon. await self._blueprint_mgr.load(skip_directory_refresh=True) logger.info("Using blueprint: %s", self._blueprint_mgr.get_blueprint()) @@ -239,7 +252,8 @@ async def _run_setup(self) -> None: if not self._is_stub_mode: self._qlogger_refresh_task = asyncio.create_task(self._refresh_qlogger()) - self._watchdog.start(asyncio.get_running_loop()) + + self._watchdog.start(self._main_thread_loop) self._ping_watchdog_task = asyncio.create_task(self._ping_watchdog()) async def _set_up_router(self) -> None: diff --git a/src/brad/row_list.py b/src/brad/row_list.py index f070549a..04f8547b 100644 --- a/src/brad/row_list.py +++ b/src/brad/row_list.py @@ -2,3 +2,6 @@ RowList = List[Tuple[Any, ...]] + +# Note: pybind11 does not support the full generic std::any type +FixedRowList = List[Tuple[int]] From c99ae8bebd4a9bc7b7185710a475fb34dde75ef3 Mon Sep 17 00:00:00 2001 From: Sophie Zhang Date: Sun, 31 Mar 2024 14:45:13 -0400 Subject: [PATCH 02/19] Create (unprotected) map to store and retrieve query data in GetFlightInfoStatement, DoGetStatement --- cpp/server/brad_server_simple.cc | 16 ++++++++++++++-- cpp/server/brad_server_simple.h | 3 +++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/cpp/server/brad_server_simple.cc b/cpp/server/brad_server_simple.cc index f77e3bca..61dd95b4 100644 --- a/cpp/server/brad_server_simple.cc +++ b/cpp/server/brad_server_simple.cc @@ -36,6 +36,12 @@ arrow::Result EncodeTransactionQuery( return Ticket{std::move(ticket_string)}; } +std::string GetQueryTicket( + const std::string &query, + const std::string &transaction_id) { + return transaction_id + ':' + query; +} + arrow::Result> DecodeTransactionQuery( const std::string &ticket) { auto divider = ticket.find(':'); @@ -89,12 +95,15 @@ arrow::Result> const FlightDescriptor &descriptor) { const std::string &query = command.query; - _handle_query(query); - ARROW_ASSIGN_OR_RAISE(auto statement, BradStatement::Create(query)); ARROW_ASSIGN_OR_RAISE(auto schema, statement->GetSchema()); ARROW_ASSIGN_OR_RAISE(auto ticket, EncodeTransactionQuery(query, command.transaction_id)); + + const std::string &query_ticket = GetQueryTicket(query, command.transaction_id); + const auto query_result = _handle_query(query); + _query_data.insert({query_ticket, query_result}); + std::vector endpoints{ FlightEndpoint{std::move(ticket), {}, std::nullopt, ""}}; @@ -118,6 +127,9 @@ arrow::Result> const std::string &sql = pair.first; const std::string transaction_id = pair.second; + const std::string &query_ticket = transaction_id + ':' + sql; + const auto query_result = _query_data.at(query_ticket); + std::shared_ptr statement; ARROW_ASSIGN_OR_RAISE(statement, BradStatement::Create(sql)); diff --git a/cpp/server/brad_server_simple.h b/cpp/server/brad_server_simple.h index 4187a91e..c1ecf377 100644 --- a/cpp/server/brad_server_simple.h +++ b/cpp/server/brad_server_simple.h @@ -38,7 +38,10 @@ class BradFlightSqlServer : public arrow::flight::sql::FlightSqlServerBase { const arrow::flight::ServerCallContext &context, const arrow::flight::sql::StatementQueryTicket &command) override; + // TODO: Create and reuse type for RowList std::function>(std::string)> _handle_query; + + std::unordered_map>> _query_data; }; } // namespace brad From 65812846abdcd9e6297435205479e9006a6d6a35 Mon Sep 17 00:00:00 2001 From: Sophie Zhang Date: Mon, 1 Apr 2024 18:54:27 -0400 Subject: [PATCH 03/19] Modify brad server to store query result, very minimal schema building from query field types --- cpp/server/brad_server_simple.cc | 10 +++++++--- cpp/server/brad_statement.cc | 27 ++++++++++++++++++++++++--- cpp/server/brad_statement.h | 7 +++++++ 3 files changed, 38 insertions(+), 6 deletions(-) diff --git a/cpp/server/brad_server_simple.cc b/cpp/server/brad_server_simple.cc index 61dd95b4..0ebacf38 100644 --- a/cpp/server/brad_server_simple.cc +++ b/cpp/server/brad_server_simple.cc @@ -95,8 +95,8 @@ arrow::Result> const FlightDescriptor &descriptor) { const std::string &query = command.query; - ARROW_ASSIGN_OR_RAISE(auto statement, BradStatement::Create(query)); - ARROW_ASSIGN_OR_RAISE(auto schema, statement->GetSchema()); + // ARROW_ASSIGN_OR_RAISE(auto statement, BradStatement::Create(query)); + // ARROW_ASSIGN_OR_RAISE(auto schema, statement->GetSchema()); ARROW_ASSIGN_OR_RAISE(auto ticket, EncodeTransactionQuery(query, command.transaction_id)); @@ -104,6 +104,9 @@ arrow::Result> const auto query_result = _handle_query(query); _query_data.insert({query_ticket, query_result}); + ARROW_ASSIGN_OR_RAISE(auto statement, BradStatement::Create(query_result)); + ARROW_ASSIGN_OR_RAISE(auto schema, statement->GetSchema()); + std::vector endpoints{ FlightEndpoint{std::move(ticket), {}, std::nullopt, ""}}; @@ -131,7 +134,8 @@ arrow::Result> const auto query_result = _query_data.at(query_ticket); std::shared_ptr statement; - ARROW_ASSIGN_OR_RAISE(statement, BradStatement::Create(sql)); + // ARROW_ASSIGN_OR_RAISE(statement, BradStatement::Create(sql)); + ARROW_ASSIGN_OR_RAISE(statement, BradStatement::Create(query_result)); std::shared_ptr reader; ARROW_ASSIGN_OR_RAISE(reader, BradStatementBatchReader::Create(statement)); diff --git a/cpp/server/brad_statement.cc b/cpp/server/brad_statement.cc index 7f791c5a..1f6ae58e 100644 --- a/cpp/server/brad_statement.cc +++ b/cpp/server/brad_statement.cc @@ -13,6 +13,9 @@ #include #include +#include +#include + namespace brad { using arrow::internal::checked_cast; @@ -24,14 +27,31 @@ arrow::Result> BradStatement::Create( return result; } +arrow::Result> BradStatement::Create( + std::vector> query_result) { + std::shared_ptr result( + new BradStatement(query_result)); + return result; +} + +BradStatement::BradStatement(std::vector> query_result) { + query_result_ = query_result; +} + BradStatement::~BradStatement() { } arrow::Result> BradStatement::GetSchema() const { std::vector> fields; - fields.push_back(arrow::field("Day", arrow::int8())); - fields.push_back(arrow::field("Month", arrow::int8())); - fields.push_back(arrow::field("Year", arrow::int16())); + const auto row = query_result_[0]; + std::string field_type = typeid(std::get<0>(row)).name(); + + if (field_type == "i") { + fields.push_back(arrow::field("Field 1", arrow::int8())); + } else { + fields.push_back(arrow::field("Field 1", arrow::int16())); + } + return arrow::schema(fields); } @@ -56,6 +76,7 @@ arrow::Result> BradStatement::FetchResult() std::shared_ptr record_batch; arrow::Result> result = GetSchema(); + if (result.ok()) { std::shared_ptr schema = result.ValueOrDie(); record_batch = arrow::RecordBatch::Make(schema, diff --git a/cpp/server/brad_statement.h b/cpp/server/brad_statement.h index 482829c9..ba451bfb 100644 --- a/cpp/server/brad_statement.h +++ b/cpp/server/brad_statement.h @@ -23,6 +23,11 @@ class BradStatement { static arrow::Result> Create( const std::string& sql); + static arrow::Result> Create( + const std::vector>); + + BradStatement(std::vector>); + ~BradStatement(); /// \brief Creates an Arrow Schema based on the results of this statement. @@ -33,6 +38,8 @@ class BradStatement { std::string* GetBradStmt() const; + std::vector> query_result_; + private: std::string* stmt_; From 71dab24b3491cb9132379055954ab9e2a2a063f0 Mon Sep 17 00:00:00 2001 From: Sophie Zhang Date: Wed, 3 Apr 2024 15:47:05 -0400 Subject: [PATCH 04/19] Move brad_server_lib sources into pybind module, acquire GIL and protect map with mutex in brad_server_simple --- cpp/CMakeLists.txt | 29 ++++++++++------------------- cpp/server/brad_server_simple.cc | 18 ++++++++++++++++-- cpp/server/brad_server_simple.h | 2 ++ 3 files changed, 28 insertions(+), 21 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 95d7fd67..6465e461 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -16,13 +16,6 @@ find_package(Boost REQUIRED) add_subdirectory(third_party) -add_library(brad_server_lib OBJECT - server/brad_server_simple.cc - server/brad_sql_info.cc - server/brad_statement_batch_reader.cc - server/brad_statement.cc - server/brad_tables_schema_batch_reader.cc) - add_library(sqlite_server_lib OBJECT sqlite_server/sqlite_server.cc sqlite_server/sqlite_sql_info.cc @@ -31,12 +24,18 @@ add_library(sqlite_server_lib OBJECT sqlite_server/sqlite_tables_schema_batch_reader.cc sqlite_server/sqlite_type_info.cc) -pybind11_add_module(pybind_brad_server pybind/brad_server.cc) +pybind11_add_module(pybind_brad_server pybind/brad_server.cc + server/brad_server_simple.cc + server/brad_sql_info.cc + server/brad_statement_batch_reader.cc + server/brad_statement.cc + server/brad_tables_schema_batch_reader.cc) + + target_link_libraries(pybind_brad_server PRIVATE Arrow::arrow_shared PRIVATE ArrowFlight::arrow_flight_shared - PRIVATE ArrowFlightSql::arrow_flight_sql_shared - PRIVATE brad_server_lib) + PRIVATE ArrowFlightSql::arrow_flight_sql_shared) add_executable(flight_sql_example_client flight_sql_example_client.cc) target_link_libraries(flight_sql_example_client @@ -55,17 +54,9 @@ target_link_libraries(flight_sql_example_server ${SQLite3_LIBRARIES} ${Boost_LIBRARIES}) -add_executable(flight_sql_brad_server flight_sql_brad_server.cc) -target_link_libraries(flight_sql_brad_server - PRIVATE Arrow::arrow_shared - PRIVATE ArrowFlight::arrow_flight_shared - PRIVATE ArrowFlightSql::arrow_flight_sql_shared - PRIVATE brad_server_lib - gflags) - add_executable(brad_front_end brad_front_end.cc) target_link_libraries(brad_front_end PRIVATE Arrow::arrow_shared PRIVATE ArrowFlight::arrow_flight_shared PRIVATE ArrowFlightSql::arrow_flight_sql_shared - gflags) + gflags) \ No newline at end of file diff --git a/cpp/server/brad_server_simple.cc b/cpp/server/brad_server_simple.cc index 0ebacf38..82c349bb 100644 --- a/cpp/server/brad_server_simple.cc +++ b/cpp/server/brad_server_simple.cc @@ -19,6 +19,10 @@ #include #include +#include + +namespace py = pybind11; + namespace brad { using arrow::internal::checked_cast; @@ -101,8 +105,18 @@ arrow::Result> EncodeTransactionQuery(query, command.transaction_id)); const std::string &query_ticket = GetQueryTicket(query, command.transaction_id); - const auto query_result = _handle_query(query); - _query_data.insert({query_ticket, query_result}); + + std::vector> query_result; + + { + py::gil_scoped_acquire guard; + query_result = _handle_query(query); + } + + { + std::scoped_lock guard(_query_data_mutex); + _query_data.insert({query_ticket, query_result}); + } ARROW_ASSIGN_OR_RAISE(auto statement, BradStatement::Create(query_result)); ARROW_ASSIGN_OR_RAISE(auto schema, statement->GetSchema()); diff --git a/cpp/server/brad_server_simple.h b/cpp/server/brad_server_simple.h index c1ecf377..36864a2e 100644 --- a/cpp/server/brad_server_simple.h +++ b/cpp/server/brad_server_simple.h @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -42,6 +43,7 @@ class BradFlightSqlServer : public arrow::flight::sql::FlightSqlServerBase { std::function>(std::string)> _handle_query; std::unordered_map>> _query_data; + std::mutex _query_data_mutex; }; } // namespace brad From 377806ce92714bad2aa7842a120aa83481c503f0 Mon Sep 17 00:00:00 2001 From: Sophie Zhang Date: Fri, 5 Apr 2024 09:08:49 -0400 Subject: [PATCH 05/19] Use autoincrement id in place of queries for ticket identifier --- cpp/server/brad_server_simple.cc | 35 ++++++++++++++++++++------------ cpp/server/brad_server_simple.h | 3 +++ 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/cpp/server/brad_server_simple.cc b/cpp/server/brad_server_simple.cc index 82c349bb..9333572d 100644 --- a/cpp/server/brad_server_simple.cc +++ b/cpp/server/brad_server_simple.cc @@ -30,16 +30,23 @@ using namespace arrow::flight; using namespace arrow::flight::sql; arrow::Result EncodeTransactionQuery( - const std::string &query, - const std::string &transaction_id) { - std::string transaction_query = transaction_id; - transaction_query += ':'; - transaction_query += query; + const std::string &query_ticket) { ARROW_ASSIGN_OR_RAISE(auto ticket_string, - CreateStatementQueryTicket(transaction_query)); + CreateStatementQueryTicket(query_ticket)); return Ticket{std::move(ticket_string)}; } +// arrow::Result EncodeTransactionQuery( +// const std::string &query, +// const std::string &transaction_id) { +// std::string transaction_query = transaction_id; +// transaction_query += ':'; +// transaction_query += query; +// ARROW_ASSIGN_OR_RAISE(auto ticket_string, +// CreateStatementQueryTicket(transaction_query)); +// return Ticket{std::move(ticket_string)}; +// } + std::string GetQueryTicket( const std::string &query, const std::string &transaction_id) { @@ -53,8 +60,8 @@ arrow::Result> DecodeTransactionQuery( return arrow::Status::Invalid("Malformed ticket"); } std::string transaction_id = ticket.substr(0, divider); - std::string query = ticket.substr(divider + 1); - return std::make_pair(std::move(query), std::move(transaction_id)); + std::string autoincrement_id = ticket.substr(divider + 1); + return std::make_pair(std::move(autoincrement_id), std::move(transaction_id)); } BradFlightSqlServer::BradFlightSqlServer() = default; @@ -101,10 +108,12 @@ arrow::Result> // ARROW_ASSIGN_OR_RAISE(auto statement, BradStatement::Create(query)); // ARROW_ASSIGN_OR_RAISE(auto schema, statement->GetSchema()); + const std::string &autoincrement_id = std::to_string(++_autoincrement_id); + const std::string &query_ticket = GetQueryTicket(autoincrement_id, command.transaction_id); + // ARROW_ASSIGN_OR_RAISE(auto ticket, + // EncodeTransactionQuery(autoincrement_id, command.transaction_id)); ARROW_ASSIGN_OR_RAISE(auto ticket, - EncodeTransactionQuery(query, command.transaction_id)); - - const std::string &query_ticket = GetQueryTicket(query, command.transaction_id); + EncodeTransactionQuery(query_ticket)); std::vector> query_result; @@ -141,10 +150,10 @@ arrow::Result> const StatementQueryTicket &command) { ARROW_ASSIGN_OR_RAISE(auto pair, DecodeTransactionQuery(command.statement_handle)); - const std::string &sql = pair.first; + const std::string &autoincrement_id = pair.first; const std::string transaction_id = pair.second; - const std::string &query_ticket = transaction_id + ':' + sql; + const std::string &query_ticket = transaction_id + ':' + autoincrement_id; const auto query_result = _query_data.at(query_ticket); std::shared_ptr statement; diff --git a/cpp/server/brad_server_simple.h b/cpp/server/brad_server_simple.h index 36864a2e..705687db 100644 --- a/cpp/server/brad_server_simple.h +++ b/cpp/server/brad_server_simple.h @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -44,6 +45,8 @@ class BradFlightSqlServer : public arrow::flight::sql::FlightSqlServerBase { std::unordered_map>> _query_data; std::mutex _query_data_mutex; + + std::atomic _autoincrement_id; }; } // namespace brad From 213fd0f93fc3488dcbc7429bcc4dbe3d4946646a Mon Sep 17 00:00:00 2001 From: Sophie Zhang Date: Wed, 10 Apr 2024 16:11:58 -0400 Subject: [PATCH 06/19] Replace fixed rowlist type with std::vector and std::vector, use placeholder value in query_data map --- cpp/server/brad_server_simple.cc | 34 +++++++++++--------------------- cpp/server/brad_server_simple.h | 12 +++++++---- cpp/server/brad_statement.cc | 27 +++++++++++++++---------- cpp/server/brad_statement.h | 12 ++++++++--- src/brad/front_end/front_end.py | 4 ++-- src/brad/row_list.py | 5 +---- 6 files changed, 48 insertions(+), 46 deletions(-) diff --git a/cpp/server/brad_server_simple.cc b/cpp/server/brad_server_simple.cc index 9333572d..376a710c 100644 --- a/cpp/server/brad_server_simple.cc +++ b/cpp/server/brad_server_simple.cc @@ -36,17 +36,6 @@ arrow::Result EncodeTransactionQuery( return Ticket{std::move(ticket_string)}; } -// arrow::Result EncodeTransactionQuery( -// const std::string &query, -// const std::string &transaction_id) { -// std::string transaction_query = transaction_id; -// transaction_query += ':'; -// transaction_query += query; -// ARROW_ASSIGN_OR_RAISE(auto ticket_string, -// CreateStatementQueryTicket(transaction_query)); -// return Ticket{std::move(ticket_string)}; -// } - std::string GetQueryTicket( const std::string &query, const std::string &transaction_id) { @@ -70,7 +59,6 @@ BradFlightSqlServer::~BradFlightSqlServer() = default; std::shared_ptr BradFlightSqlServer::Create() { - // std::shared_ptr result(new BradFlightSqlServer()); std::shared_ptr result = std::make_shared(); for (const auto &id_to_result : GetSqlInfoResultMap()) { @@ -82,7 +70,7 @@ std::shared_ptr void BradFlightSqlServer::InitWrapper( const std::string &host, int port, - std::function>(std::string)> handle_query) { + std::function(std::string)> handle_query) { auto location = arrow::flight::Location::ForGrpcTcp(host, port).ValueOrDie(); arrow::flight::FlightServerOptions options(location); @@ -106,28 +94,29 @@ arrow::Result> const FlightDescriptor &descriptor) { const std::string &query = command.query; - // ARROW_ASSIGN_OR_RAISE(auto statement, BradStatement::Create(query)); - // ARROW_ASSIGN_OR_RAISE(auto schema, statement->GetSchema()); const std::string &autoincrement_id = std::to_string(++_autoincrement_id); const std::string &query_ticket = GetQueryTicket(autoincrement_id, command.transaction_id); - // ARROW_ASSIGN_OR_RAISE(auto ticket, - // EncodeTransactionQuery(autoincrement_id, command.transaction_id)); ARROW_ASSIGN_OR_RAISE(auto ticket, EncodeTransactionQuery(query_ticket)); - std::vector> query_result; - { py::gil_scoped_acquire guard; + // TODO: define function to convert py::tuple to std::any + std::vector query_result; query_result = _handle_query(query); } + // TODO: remove + std::vector dummy_result; { std::scoped_lock guard(_query_data_mutex); - _query_data.insert({query_ticket, query_result}); - } + // TODO: replace with query_result + dummy_result.push_back(8); + _query_data.insert({query_ticket, dummy_result}); + } - ARROW_ASSIGN_OR_RAISE(auto statement, BradStatement::Create(query_result)); + // TODO: Replace with query_result + ARROW_ASSIGN_OR_RAISE(auto statement, BradStatement::Create(dummy_result)); ARROW_ASSIGN_OR_RAISE(auto schema, statement->GetSchema()); std::vector endpoints{ @@ -157,7 +146,6 @@ arrow::Result> const auto query_result = _query_data.at(query_ticket); std::shared_ptr statement; - // ARROW_ASSIGN_OR_RAISE(statement, BradStatement::Create(sql)); ARROW_ASSIGN_OR_RAISE(statement, BradStatement::Create(query_result)); std::shared_ptr reader; diff --git a/cpp/server/brad_server_simple.h b/cpp/server/brad_server_simple.h index 705687db..30ddeb84 100644 --- a/cpp/server/brad_server_simple.h +++ b/cpp/server/brad_server_simple.h @@ -11,6 +11,11 @@ #include #include +#include + +namespace py = pybind11; +using namespace pybind11::literals; + namespace brad { class BradFlightSqlServer : public arrow::flight::sql::FlightSqlServerBase { @@ -23,7 +28,7 @@ class BradFlightSqlServer : public arrow::flight::sql::FlightSqlServerBase { void InitWrapper(const std::string &host, int port, - std::function>(std::string)>); + std::function(std::string)>); void ServeWrapper(); @@ -40,10 +45,9 @@ class BradFlightSqlServer : public arrow::flight::sql::FlightSqlServerBase { const arrow::flight::ServerCallContext &context, const arrow::flight::sql::StatementQueryTicket &command) override; - // TODO: Create and reuse type for RowList - std::function>(std::string)> _handle_query; + std::function(std::string)> _handle_query; - std::unordered_map>> _query_data; + std::unordered_map> _query_data; std::mutex _query_data_mutex; std::atomic _autoincrement_id; diff --git a/cpp/server/brad_statement.cc b/cpp/server/brad_statement.cc index 1f6ae58e..873773fa 100644 --- a/cpp/server/brad_statement.cc +++ b/cpp/server/brad_statement.cc @@ -16,6 +16,11 @@ #include #include +#include + +namespace py = pybind11; +using namespace pybind11::literals; + namespace brad { using arrow::internal::checked_cast; @@ -28,14 +33,14 @@ arrow::Result> BradStatement::Create( } arrow::Result> BradStatement::Create( - std::vector> query_result) { + std::vector query_result) { std::shared_ptr result( new BradStatement(query_result)); return result; } -BradStatement::BradStatement(std::vector> query_result) { - query_result_ = query_result; +BradStatement::BradStatement(std::vector query_result) { + query_result_ = query_result; } BradStatement::~BradStatement() { @@ -43,14 +48,16 @@ BradStatement::~BradStatement() { arrow::Result> BradStatement::GetSchema() const { std::vector> fields; - const auto row = query_result_[0]; - std::string field_type = typeid(std::get<0>(row)).name(); + // const auto row = query_result_[0]; + // std::string field_type = typeid(std::get<0>(row)).name(); - if (field_type == "i") { - fields.push_back(arrow::field("Field 1", arrow::int8())); - } else { - fields.push_back(arrow::field("Field 1", arrow::int16())); - } + // std::string field_type = typeid(row[0]).name(); + + // if (field_type == "i") { + // fields.push_back(arrow::field("Field 1", arrow::int8())); + // } else { + // fields.push_back(arrow::field("Field 1", arrow::int16())); + // } return arrow::schema(fields); } diff --git a/cpp/server/brad_statement.h b/cpp/server/brad_statement.h index ba451bfb..9beb5ee7 100644 --- a/cpp/server/brad_statement.h +++ b/cpp/server/brad_statement.h @@ -1,11 +1,17 @@ #pragma once #include +#include #include #include #include +#include + +namespace py = pybind11; +using namespace pybind11::literals; + namespace brad { /// \brief Create an object ColumnMetadata using the column type and @@ -24,9 +30,9 @@ class BradStatement { const std::string& sql); static arrow::Result> Create( - const std::vector>); + const std::vector); - BradStatement(std::vector>); + BradStatement(std::vector); ~BradStatement(); @@ -38,7 +44,7 @@ class BradStatement { std::string* GetBradStmt() const; - std::vector> query_result_; + std::vector query_result_; private: std::string* stmt_; diff --git a/src/brad/front_end/front_end.py b/src/brad/front_end/front_end.py index ac3a4711..08267a02 100644 --- a/src/brad/front_end/front_end.py +++ b/src/brad/front_end/front_end.py @@ -45,7 +45,7 @@ from brad.routing.policy import RoutingPolicy from brad.routing.router import Router from brad.routing.tree_based.forest_policy import ForestPolicy -from brad.row_list import RowList, FixedRowList +from brad.row_list import RowList from brad.utils import log_verbose, create_custom_logger from brad.utils.counter import Counter from brad.utils.json_decimal_encoder import DecimalEncoder @@ -191,7 +191,7 @@ def __init__( self._is_stub_mode = self._config.stub_mode_path is not None - def _handle_query_from_flight_sql(self, query: str) -> FixedRowList: + def _handle_query_from_flight_sql(self, query: str) -> RowList: future = asyncio.run_coroutine_threadsafe( self._run_query_impl(self._flight_sql_server_session_id, query, {}), self._main_thread_loop diff --git a/src/brad/row_list.py b/src/brad/row_list.py index 04f8547b..a64ad067 100644 --- a/src/brad/row_list.py +++ b/src/brad/row_list.py @@ -1,7 +1,4 @@ from typing import Any, List, Tuple -RowList = List[Tuple[Any, ...]] - -# Note: pybind11 does not support the full generic std::any type -FixedRowList = List[Tuple[int]] +RowList = List[Tuple[Any, ...]] \ No newline at end of file From a88a7c13ddecda1697e290ef2a06d21864d4481e Mon Sep 17 00:00:00 2001 From: Sophie Zhang Date: Thu, 11 Apr 2024 18:49:22 -0400 Subject: [PATCH 07/19] Add function to transform std::vector to std::vector> type --- cpp/server/brad_server_simple.cc | 26 ++++++++++++++++++-------- cpp/server/brad_server_simple.h | 2 +- cpp/server/brad_statement.cc | 4 ++-- cpp/server/brad_statement.h | 6 +++--- 4 files changed, 24 insertions(+), 14 deletions(-) diff --git a/cpp/server/brad_server_simple.cc b/cpp/server/brad_server_simple.cc index 376a710c..83639916 100644 --- a/cpp/server/brad_server_simple.cc +++ b/cpp/server/brad_server_simple.cc @@ -53,6 +53,19 @@ arrow::Result> DecodeTransactionQuery( return std::make_pair(std::move(autoincrement_id), std::move(transaction_id)); } +std::vector> TransformQueryResult( + std::vector query_result) { + std::vector> transformed_query_result; + for (const auto &tup : query_result) { + std::vector transformed_tup{}; + for (const auto &elt : tup) { + transformed_tup.push_back(elt); + } + transformed_query_result.push_back(transformed_tup); + } + return transformed_query_result; +} + BradFlightSqlServer::BradFlightSqlServer() = default; BradFlightSqlServer::~BradFlightSqlServer() = default; @@ -99,24 +112,21 @@ arrow::Result> ARROW_ASSIGN_OR_RAISE(auto ticket, EncodeTransactionQuery(query_ticket)); + std::vector> transformed_query_result; + { py::gil_scoped_acquire guard; - // TODO: define function to convert py::tuple to std::any std::vector query_result; query_result = _handle_query(query); + transformed_query_result = TransformQueryResult(query_result); } - // TODO: remove - std::vector dummy_result; { std::scoped_lock guard(_query_data_mutex); - // TODO: replace with query_result - dummy_result.push_back(8); - _query_data.insert({query_ticket, dummy_result}); + _query_data.insert({query_ticket, transformed_query_result}); } - // TODO: Replace with query_result - ARROW_ASSIGN_OR_RAISE(auto statement, BradStatement::Create(dummy_result)); + ARROW_ASSIGN_OR_RAISE(auto statement, BradStatement::Create(transformed_query_result)); ARROW_ASSIGN_OR_RAISE(auto schema, statement->GetSchema()); std::vector endpoints{ diff --git a/cpp/server/brad_server_simple.h b/cpp/server/brad_server_simple.h index 30ddeb84..b5358cf4 100644 --- a/cpp/server/brad_server_simple.h +++ b/cpp/server/brad_server_simple.h @@ -47,7 +47,7 @@ class BradFlightSqlServer : public arrow::flight::sql::FlightSqlServerBase { std::function(std::string)> _handle_query; - std::unordered_map> _query_data; + std::unordered_map>> _query_data; std::mutex _query_data_mutex; std::atomic _autoincrement_id; diff --git a/cpp/server/brad_statement.cc b/cpp/server/brad_statement.cc index 873773fa..e1802794 100644 --- a/cpp/server/brad_statement.cc +++ b/cpp/server/brad_statement.cc @@ -33,13 +33,13 @@ arrow::Result> BradStatement::Create( } arrow::Result> BradStatement::Create( - std::vector query_result) { + std::vector> query_result) { std::shared_ptr result( new BradStatement(query_result)); return result; } -BradStatement::BradStatement(std::vector query_result) { +BradStatement::BradStatement(std::vector> query_result) { query_result_ = query_result; } diff --git a/cpp/server/brad_statement.h b/cpp/server/brad_statement.h index 9beb5ee7..5f89e77c 100644 --- a/cpp/server/brad_statement.h +++ b/cpp/server/brad_statement.h @@ -30,9 +30,9 @@ class BradStatement { const std::string& sql); static arrow::Result> Create( - const std::vector); + const std::vector>); - BradStatement(std::vector); + BradStatement(std::vector>); ~BradStatement(); @@ -44,7 +44,7 @@ class BradStatement { std::string* GetBradStmt() const; - std::vector query_result_; + std::vector> query_result_; private: std::string* stmt_; From fc528a0f875cdae533078f2696a78862a5d8cb1d Mon Sep 17 00:00:00 2001 From: Sophie Zhang Date: Thu, 11 Apr 2024 19:02:03 -0400 Subject: [PATCH 08/19] Get schema field types from query result --- cpp/server/brad_statement.cc | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/cpp/server/brad_statement.cc b/cpp/server/brad_statement.cc index e1802794..cdfee771 100644 --- a/cpp/server/brad_statement.cc +++ b/cpp/server/brad_statement.cc @@ -15,6 +15,7 @@ #include #include +#include #include @@ -48,16 +49,15 @@ BradStatement::~BradStatement() { arrow::Result> BradStatement::GetSchema() const { std::vector> fields; - // const auto row = query_result_[0]; - // std::string field_type = typeid(std::get<0>(row)).name(); - - // std::string field_type = typeid(row[0]).name(); - - // if (field_type == "i") { - // fields.push_back(arrow::field("Field 1", arrow::int8())); - // } else { - // fields.push_back(arrow::field("Field 1", arrow::int16())); - // } + const std::vector &row = query_result_[0]; + + for (const auto &elt : row) { + if (std::is_floating_point::value) { + fields.push_back(arrow::field("FLOAT FIELD", arrow::int16())); + } else { + fields.push_back(arrow::field("INT FIELD", arrow::int8())); + } + } return arrow::schema(fields); } From e8cc5fe63fa2ab9b8d5673d5902ac4539caf7b6a Mon Sep 17 00:00:00 2001 From: Sophie Zhang Date: Sat, 13 Apr 2024 23:16:49 -0400 Subject: [PATCH 09/19] Fix name formatting for cpp class variables --- cpp/server/brad_server_simple.cc | 12 ++++++------ cpp/server/brad_server_simple.h | 8 ++++---- cpp/server/brad_statement.cc | 2 +- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/cpp/server/brad_server_simple.cc b/cpp/server/brad_server_simple.cc index 83639916..369902b9 100644 --- a/cpp/server/brad_server_simple.cc +++ b/cpp/server/brad_server_simple.cc @@ -87,7 +87,7 @@ void BradFlightSqlServer::InitWrapper( auto location = arrow::flight::Location::ForGrpcTcp(host, port).ValueOrDie(); arrow::flight::FlightServerOptions options(location); - _handle_query = handle_query; + handle_query_ = handle_query; this->Init(options); } @@ -107,7 +107,7 @@ arrow::Result> const FlightDescriptor &descriptor) { const std::string &query = command.query; - const std::string &autoincrement_id = std::to_string(++_autoincrement_id); + const std::string &autoincrement_id = std::to_string(++autoincrement_id_); const std::string &query_ticket = GetQueryTicket(autoincrement_id, command.transaction_id); ARROW_ASSIGN_OR_RAISE(auto ticket, EncodeTransactionQuery(query_ticket)); @@ -117,13 +117,13 @@ arrow::Result> { py::gil_scoped_acquire guard; std::vector query_result; - query_result = _handle_query(query); + query_result = handle_query_(query); transformed_query_result = TransformQueryResult(query_result); } { - std::scoped_lock guard(_query_data_mutex); - _query_data.insert({query_ticket, transformed_query_result}); + std::scoped_lock guard(query_data_mutex_); + query_data_.insert({query_ticket, transformed_query_result}); } ARROW_ASSIGN_OR_RAISE(auto statement, BradStatement::Create(transformed_query_result)); @@ -153,7 +153,7 @@ arrow::Result> const std::string transaction_id = pair.second; const std::string &query_ticket = transaction_id + ':' + autoincrement_id; - const auto query_result = _query_data.at(query_ticket); + const auto query_result = query_data_.at(query_ticket); std::shared_ptr statement; ARROW_ASSIGN_OR_RAISE(statement, BradStatement::Create(query_result)); diff --git a/cpp/server/brad_server_simple.h b/cpp/server/brad_server_simple.h index b5358cf4..dd770e81 100644 --- a/cpp/server/brad_server_simple.h +++ b/cpp/server/brad_server_simple.h @@ -45,12 +45,12 @@ class BradFlightSqlServer : public arrow::flight::sql::FlightSqlServerBase { const arrow::flight::ServerCallContext &context, const arrow::flight::sql::StatementQueryTicket &command) override; - std::function(std::string)> _handle_query; + std::function(std::string)> handle_query_; - std::unordered_map>> _query_data; - std::mutex _query_data_mutex; + std::unordered_map>> query_data_; + std::mutex query_data_mutex_; - std::atomic _autoincrement_id; + std::atomic autoincrement_id_; }; } // namespace brad diff --git a/cpp/server/brad_statement.cc b/cpp/server/brad_statement.cc index cdfee771..8d305900 100644 --- a/cpp/server/brad_statement.cc +++ b/cpp/server/brad_statement.cc @@ -53,7 +53,7 @@ arrow::Result> BradStatement::GetSchema() const { for (const auto &elt : row) { if (std::is_floating_point::value) { - fields.push_back(arrow::field("FLOAT FIELD", arrow::int16())); + fields.push_back(arrow::field("FLOAT FIELD", arrow::float16())); } else { fields.push_back(arrow::field("INT FIELD", arrow::int8())); } From 4a3a3e23209203fb7e26924ec3b8f6bbe77b845c Mon Sep 17 00:00:00 2001 From: Sophie Zhang Date: Sun, 14 Apr 2024 20:16:43 -0400 Subject: [PATCH 10/19] Remove unused FetchResult, use efficient concurrent hash table --- cpp/CMakeLists.txt | 11 +++++++- cpp/server/brad_server_simple.cc | 4 +-- cpp/server/brad_server_simple.h | 4 ++- cpp/server/brad_statement.cc | 33 ----------------------- cpp/server/brad_statement.h | 2 -- cpp/server/brad_statement_batch_reader.cc | 3 ++- 6 files changed, 17 insertions(+), 40 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 6465e461..a5b8f92f 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -16,6 +16,14 @@ find_package(Boost REQUIRED) add_subdirectory(third_party) +# Concurrent hash table +FetchContent_Declare( + libcuckoo + GIT_REPOSITORY https://github.com/efficient/libcuckoo.git + GIT_TAG 784d0f5d147b9a73f897ae55f6c3712d9a91b058 +) +FetchContent_MakeAvailable(libcuckoo) + add_library(sqlite_server_lib OBJECT sqlite_server/sqlite_server.cc sqlite_server/sqlite_sql_info.cc @@ -35,7 +43,8 @@ pybind11_add_module(pybind_brad_server pybind/brad_server.cc target_link_libraries(pybind_brad_server PRIVATE Arrow::arrow_shared PRIVATE ArrowFlight::arrow_flight_shared - PRIVATE ArrowFlightSql::arrow_flight_sql_shared) + PRIVATE ArrowFlightSql::arrow_flight_sql_shared + PUBLIC libcuckoo) add_executable(flight_sql_example_client flight_sql_example_client.cc) target_link_libraries(flight_sql_example_client diff --git a/cpp/server/brad_server_simple.cc b/cpp/server/brad_server_simple.cc index 369902b9..3f42e6ed 100644 --- a/cpp/server/brad_server_simple.cc +++ b/cpp/server/brad_server_simple.cc @@ -123,7 +123,7 @@ arrow::Result> { std::scoped_lock guard(query_data_mutex_); - query_data_.insert({query_ticket, transformed_query_result}); + query_data_.insert(query_ticket, transformed_query_result); } ARROW_ASSIGN_OR_RAISE(auto statement, BradStatement::Create(transformed_query_result)); @@ -153,7 +153,7 @@ arrow::Result> const std::string transaction_id = pair.second; const std::string &query_ticket = transaction_id + ':' + autoincrement_id; - const auto query_result = query_data_.at(query_ticket); + const auto query_result = query_data_.find(query_ticket); std::shared_ptr statement; ARROW_ASSIGN_OR_RAISE(statement, BradStatement::Create(query_result)); diff --git a/cpp/server/brad_server_simple.h b/cpp/server/brad_server_simple.h index dd770e81..dea03bc8 100644 --- a/cpp/server/brad_server_simple.h +++ b/cpp/server/brad_server_simple.h @@ -11,6 +11,8 @@ #include #include +#include "libcuckoo/cuckoohash_map.hh" + #include namespace py = pybind11; @@ -47,7 +49,7 @@ class BradFlightSqlServer : public arrow::flight::sql::FlightSqlServerBase { std::function(std::string)> handle_query_; - std::unordered_map>> query_data_; + libcuckoo::cuckoohash_map>> query_data_; std::mutex query_data_mutex_; std::atomic autoincrement_id_; diff --git a/cpp/server/brad_statement.cc b/cpp/server/brad_statement.cc index 8d305900..b30fe3db 100644 --- a/cpp/server/brad_statement.cc +++ b/cpp/server/brad_statement.cc @@ -62,39 +62,6 @@ arrow::Result> BradStatement::GetSchema() const { return arrow::schema(fields); } -arrow::Result> BradStatement::FetchResult() { - arrow::Int8Builder int8builder; - int8_t days_raw[5] = {1, 12, 17, 23, 28}; - ARROW_RETURN_NOT_OK(int8builder.AppendValues(days_raw, 5)); - std::shared_ptr days; - ARROW_ASSIGN_OR_RAISE(days, int8builder.Finish()); - - int8_t months_raw[5] = {1, 3, 5, 7, 1}; - ARROW_RETURN_NOT_OK(int8builder.AppendValues(months_raw, 5)); - std::shared_ptr months; - ARROW_ASSIGN_OR_RAISE(months, int8builder.Finish()); - - arrow::Int16Builder int16builder; - int16_t years_raw[5] = {1990, 2000, 1995, 2000, 1995}; - ARROW_RETURN_NOT_OK(int16builder.AppendValues(years_raw, 5)); - std::shared_ptr years; - ARROW_ASSIGN_OR_RAISE(years, int16builder.Finish()); - - std::shared_ptr record_batch; - - arrow::Result> result = GetSchema(); - - if (result.ok()) { - std::shared_ptr schema = result.ValueOrDie(); - record_batch = arrow::RecordBatch::Make(schema, - days->length(), - {days, months, years}); - return record_batch; - } - - return arrow::Status::OK(); -} - std::string* BradStatement::GetBradStmt() const { return stmt_; } } // namespace brad diff --git a/cpp/server/brad_statement.h b/cpp/server/brad_statement.h index 5f89e77c..45b4a1e0 100644 --- a/cpp/server/brad_statement.h +++ b/cpp/server/brad_statement.h @@ -40,8 +40,6 @@ class BradStatement { /// \return The resulting Schema. arrow::Result> GetSchema() const; - arrow::Result> FetchResult(); - std::string* GetBradStmt() const; std::vector> query_result_; diff --git a/cpp/server/brad_statement_batch_reader.cc b/cpp/server/brad_statement_batch_reader.cc index 16ef38cd..56bb5e47 100644 --- a/cpp/server/brad_statement_batch_reader.cc +++ b/cpp/server/brad_statement_batch_reader.cc @@ -2,6 +2,7 @@ #include #include "brad_statement.h" +#include namespace brad { @@ -42,7 +43,7 @@ arrow::Status BradStatementBatchReader::ReadNext(std::shared_ptrFetchResult()); + // ARROW_ASSIGN_OR_RAISE(*out, statement_->FetchResult()); already_executed_ = true; return arrow::Status::OK(); } From b88e1b71464a8f76e936c5905aae717ec2ea6297 Mon Sep 17 00:00:00 2001 From: Sophie Zhang Date: Mon, 15 Apr 2024 18:02:31 -0400 Subject: [PATCH 11/19] Fix type checking for query result to make schema --- cpp/server/brad_server_simple.cc | 8 +++++++- cpp/server/brad_statement.cc | 7 +++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/cpp/server/brad_server_simple.cc b/cpp/server/brad_server_simple.cc index 3f42e6ed..805d7c7e 100644 --- a/cpp/server/brad_server_simple.cc +++ b/cpp/server/brad_server_simple.cc @@ -59,7 +59,13 @@ std::vector> TransformQueryResult( for (const auto &tup : query_result) { std::vector transformed_tup{}; for (const auto &elt : tup) { - transformed_tup.push_back(elt); + if (py::isinstance(elt)) { + transformed_tup.push_back(std::make_any(py::cast(elt))); + } else if (py::isinstance(elt)) { + transformed_tup.push_back(std::make_any(py::cast(elt))); + } else { + transformed_tup.push_back(std::make_any(py::cast(elt))); + } } transformed_query_result.push_back(transformed_tup); } diff --git a/cpp/server/brad_statement.cc b/cpp/server/brad_statement.cc index b30fe3db..d64d863b 100644 --- a/cpp/server/brad_statement.cc +++ b/cpp/server/brad_statement.cc @@ -52,10 +52,13 @@ arrow::Result> BradStatement::GetSchema() const { const std::vector &row = query_result_[0]; for (const auto &elt : row) { - if (std::is_floating_point::value) { + std::string elt_type = elt.type().name(); + if (elt_type == "i") { + fields.push_back(arrow::field("INT FIELD", arrow::int8())); + } else if (elt_type == "f") { fields.push_back(arrow::field("FLOAT FIELD", arrow::float16())); } else { - fields.push_back(arrow::field("INT FIELD", arrow::int8())); + fields.push_back(arrow::field("STRING FIELD", arrow::utf8())); } } From 2a4bbffa27c99d03c5d39bfdbccbffea1a569495 Mon Sep 17 00:00:00 2001 From: Sophie Zhang Date: Mon, 15 Apr 2024 23:37:15 -0400 Subject: [PATCH 12/19] Fix formatting --- cpp/CMakeLists.txt | 1 - cpp/server/brad_server_simple.cc | 16 +++++++--------- cpp/server/brad_statement.cc | 12 ++++-------- cpp/server/brad_statement_batch_reader.cc | 1 - src/brad/front_end/flight_sql_server.py | 3 ++- src/brad/front_end/front_end.py | 2 +- src/brad/row_list.py | 2 +- 7 files changed, 15 insertions(+), 22 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index a5b8f92f..4c74b7be 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -39,7 +39,6 @@ pybind11_add_module(pybind_brad_server pybind/brad_server.cc server/brad_statement.cc server/brad_tables_schema_batch_reader.cc) - target_link_libraries(pybind_brad_server PRIVATE Arrow::arrow_shared PRIVATE ArrowFlight::arrow_flight_shared diff --git a/cpp/server/brad_server_simple.cc b/cpp/server/brad_server_simple.cc index 805d7c7e..d2007b19 100644 --- a/cpp/server/brad_server_simple.cc +++ b/cpp/server/brad_server_simple.cc @@ -5,7 +5,6 @@ #include #include #include -#include #include #include @@ -29,6 +28,12 @@ using arrow::internal::checked_cast; using namespace arrow::flight; using namespace arrow::flight::sql; +std::string GetQueryTicket( + const std::string &autoincrement_id, + const std::string &transaction_id) { + return transaction_id + ':' + autoincrement_id; +} + arrow::Result EncodeTransactionQuery( const std::string &query_ticket) { ARROW_ASSIGN_OR_RAISE(auto ticket_string, @@ -36,12 +41,6 @@ arrow::Result EncodeTransactionQuery( return Ticket{std::move(ticket_string)}; } -std::string GetQueryTicket( - const std::string &query, - const std::string &transaction_id) { - return transaction_id + ':' + query; -} - arrow::Result> DecodeTransactionQuery( const std::string &ticket) { auto divider = ticket.find(':'); @@ -122,8 +121,7 @@ arrow::Result> { py::gil_scoped_acquire guard; - std::vector query_result; - query_result = handle_query_(query); + std::vector query_result = handle_query_(query); transformed_query_result = TransformQueryResult(query_result); } diff --git a/cpp/server/brad_statement.cc b/cpp/server/brad_statement.cc index d64d863b..e8d808d1 100644 --- a/cpp/server/brad_statement.cc +++ b/cpp/server/brad_statement.cc @@ -13,10 +13,6 @@ #include #include -#include -#include -#include - #include namespace py = pybind11; @@ -51,11 +47,11 @@ arrow::Result> BradStatement::GetSchema() const { std::vector> fields; const std::vector &row = query_result_[0]; - for (const auto &elt : row) { - std::string elt_type = elt.type().name(); - if (elt_type == "i") { + for (const auto &field : row) { + std::string field_type = field.type().name(); + if (field_type == "i") { fields.push_back(arrow::field("INT FIELD", arrow::int8())); - } else if (elt_type == "f") { + } else if (field_type == "f") { fields.push_back(arrow::field("FLOAT FIELD", arrow::float16())); } else { fields.push_back(arrow::field("STRING FIELD", arrow::utf8())); diff --git a/cpp/server/brad_statement_batch_reader.cc b/cpp/server/brad_statement_batch_reader.cc index 56bb5e47..9627ccda 100644 --- a/cpp/server/brad_statement_batch_reader.cc +++ b/cpp/server/brad_statement_batch_reader.cc @@ -2,7 +2,6 @@ #include #include "brad_statement.h" -#include namespace brad { diff --git a/src/brad/front_end/flight_sql_server.py b/src/brad/front_end/flight_sql_server.py index 279cd79d..22152e8e 100644 --- a/src/brad/front_end/flight_sql_server.py +++ b/src/brad/front_end/flight_sql_server.py @@ -1,5 +1,6 @@ import logging import threading +from typing import Callable # pylint: disable-next=import-error,no-name-in-module,unused-import import brad.native.pybind_brad_server as brad_server @@ -8,7 +9,7 @@ class BradFlightSqlServer: - def __init__(self, host: str, port: int, callback) -> None: + def __init__(self, host: str, port: int, callback: Callable) -> None: self._flight_sql_server = brad_server.BradFlightSqlServer() self._flight_sql_server.init(host, port, callback) self._thread = threading.Thread(name="BradFlightSqlServer", target=self._serve) diff --git a/src/brad/front_end/front_end.py b/src/brad/front_end/front_end.py index 08267a02..5771b300 100644 --- a/src/brad/front_end/front_end.py +++ b/src/brad/front_end/front_end.py @@ -59,6 +59,7 @@ LINESEP = "\n".encode() + class BradFrontEnd(BradInterface): @staticmethod def native_server_is_supported() -> bool: @@ -252,7 +253,6 @@ async def _run_setup(self) -> None: if not self._is_stub_mode: self._qlogger_refresh_task = asyncio.create_task(self._refresh_qlogger()) - self._watchdog.start(self._main_thread_loop) self._ping_watchdog_task = asyncio.create_task(self._ping_watchdog()) diff --git a/src/brad/row_list.py b/src/brad/row_list.py index a64ad067..f070549a 100644 --- a/src/brad/row_list.py +++ b/src/brad/row_list.py @@ -1,4 +1,4 @@ from typing import Any, List, Tuple -RowList = List[Tuple[Any, ...]] \ No newline at end of file +RowList = List[Tuple[Any, ...]] From 95ef82809600a39d53d2213f83f217a7d0e40725 Mon Sep 17 00:00:00 2001 From: Sophie Zhang Date: Mon, 15 Apr 2024 23:47:42 -0400 Subject: [PATCH 13/19] Fix formatting (cont) --- cpp/CMakeLists.txt | 2 +- cpp/server/brad_server_simple.cc | 24 +++++++++--------------- cpp/server/brad_statement.cc | 5 ----- 3 files changed, 10 insertions(+), 21 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 4c74b7be..4a68f5a7 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -67,4 +67,4 @@ target_link_libraries(brad_front_end PRIVATE Arrow::arrow_shared PRIVATE ArrowFlight::arrow_flight_shared PRIVATE ArrowFlightSql::arrow_flight_sql_shared - gflags) \ No newline at end of file + gflags) diff --git a/cpp/server/brad_server_simple.cc b/cpp/server/brad_server_simple.cc index d2007b19..3a5b7b96 100644 --- a/cpp/server/brad_server_simple.cc +++ b/cpp/server/brad_server_simple.cc @@ -5,8 +5,6 @@ #include #include #include -#include -#include #include #include "brad_sql_info.h" @@ -18,10 +16,6 @@ #include #include -#include - -namespace py = pybind11; - namespace brad { using arrow::internal::checked_cast; @@ -55,18 +49,18 @@ arrow::Result> DecodeTransactionQuery( std::vector> TransformQueryResult( std::vector query_result) { std::vector> transformed_query_result; - for (const auto &tup : query_result) { - std::vector transformed_tup{}; - for (const auto &elt : tup) { - if (py::isinstance(elt)) { - transformed_tup.push_back(std::make_any(py::cast(elt))); - } else if (py::isinstance(elt)) { - transformed_tup.push_back(std::make_any(py::cast(elt))); + for (const auto &row : query_result) { + std::vector transformed_row{}; + for (const auto &field : row) { + if (py::isinstance(field)) { + transformed_row.push_back(std::make_any(py::cast(field))); + } else if (py::isinstance(field)) { + transformed_row.push_back(std::make_any(py::cast(field))); } else { - transformed_tup.push_back(std::make_any(py::cast(elt))); + transformed_row.push_back(std::make_any(py::cast(field))); } } - transformed_query_result.push_back(transformed_tup); + transformed_query_result.push_back(transformed_row); } return transformed_query_result; } diff --git a/cpp/server/brad_statement.cc b/cpp/server/brad_statement.cc index e8d808d1..40e18479 100644 --- a/cpp/server/brad_statement.cc +++ b/cpp/server/brad_statement.cc @@ -13,11 +13,6 @@ #include #include -#include - -namespace py = pybind11; -using namespace pybind11::literals; - namespace brad { using arrow::internal::checked_cast; From a76dc0d1f6ed685191e20a169189d37597489a09 Mon Sep 17 00:00:00 2001 From: Sophie Zhang Date: Tue, 16 Apr 2024 00:00:42 -0400 Subject: [PATCH 14/19] Address Python code check --- src/brad/front_end/front_end.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/brad/front_end/front_end.py b/src/brad/front_end/front_end.py index 5771b300..ad4259d3 100644 --- a/src/brad/front_end/front_end.py +++ b/src/brad/front_end/front_end.py @@ -90,13 +90,18 @@ def __init__( from brad.front_end.flight_sql_server import BradFlightSqlServer self._flight_sql_server: Optional[BradFlightSqlServer] = ( - BradFlightSqlServer(host="0.0.0.0", - port=31337, - callback=self._handle_query_from_flight_sql) + BradFlightSqlServer( + host="0.0.0.0", + port=31337, + callback=self._handle_query_from_flight_sql, + ) ) + self._flight_sql_server_session_id = None else: self._flight_sql_server = None + self._main_thread_loop = None + self._fe_index = fe_index self._config = config self._schema_name = schema_name @@ -195,7 +200,7 @@ def __init__( def _handle_query_from_flight_sql(self, query: str) -> RowList: future = asyncio.run_coroutine_threadsafe( self._run_query_impl(self._flight_sql_server_session_id, query, {}), - self._main_thread_loop + self._main_thread_loop, ) row_result = future.result() From 7d85b674a5c7976cace8e38f935fa3aa620b0306 Mon Sep 17 00:00:00 2001 From: Sophie Zhang Date: Tue, 16 Apr 2024 10:56:26 -0400 Subject: [PATCH 15/19] Address python type check errors --- src/brad/front_end/front_end.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/brad/front_end/front_end.py b/src/brad/front_end/front_end.py index ad4259d3..5af11477 100644 --- a/src/brad/front_end/front_end.py +++ b/src/brad/front_end/front_end.py @@ -96,11 +96,11 @@ def __init__( callback=self._handle_query_from_flight_sql, ) ) - self._flight_sql_server_session_id = None + self._flight_sql_server_session_id: Optional[SessionId] = None else: self._flight_sql_server = None - self._main_thread_loop = None + self._main_thread_loop: Optional[AbstractEventLoop] = None self._fe_index = fe_index self._config = config @@ -198,6 +198,8 @@ def __init__( self._is_stub_mode = self._config.stub_mode_path is not None def _handle_query_from_flight_sql(self, query: str) -> RowList: + assert self._flight_sql_server_session_id is not None + future = asyncio.run_coroutine_threadsafe( self._run_query_impl(self._flight_sql_server_session_id, query, {}), self._main_thread_loop, From 57a4961d07e634c69a7d10cb0b04e64d00c55e5e Mon Sep 17 00:00:00 2001 From: Sophie Zhang Date: Fri, 19 Apr 2024 16:32:01 -0400 Subject: [PATCH 16/19] Address PR comments --- cpp/CMakeLists.txt | 8 --- cpp/server/brad_server_simple.cc | 23 ++++---- cpp/server/brad_server_simple.h | 5 +- cpp/server/brad_statement.cc | 66 +++++++++++++++++++---- cpp/server/brad_statement.h | 8 ++- cpp/server/brad_statement_batch_reader.cc | 5 +- cpp/third_party/CMakeLists.txt | 8 ++- src/brad/front_end/front_end.py | 2 +- 8 files changed, 90 insertions(+), 35 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 4a68f5a7..cb82fa75 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -16,14 +16,6 @@ find_package(Boost REQUIRED) add_subdirectory(third_party) -# Concurrent hash table -FetchContent_Declare( - libcuckoo - GIT_REPOSITORY https://github.com/efficient/libcuckoo.git - GIT_TAG 784d0f5d147b9a73f897ae55f6c3712d9a91b058 -) -FetchContent_MakeAvailable(libcuckoo) - add_library(sqlite_server_lib OBJECT sqlite_server/sqlite_server.cc sqlite_server/sqlite_sql_info.cc diff --git a/cpp/server/brad_server_simple.cc b/cpp/server/brad_server_simple.cc index 3a5b7b96..cb090030 100644 --- a/cpp/server/brad_server_simple.cc +++ b/cpp/server/brad_server_simple.cc @@ -65,7 +65,7 @@ std::vector> TransformQueryResult( return transformed_query_result; } -BradFlightSqlServer::BradFlightSqlServer() = default; +BradFlightSqlServer::BradFlightSqlServer() : autoincrement_id_(0ULL) {} BradFlightSqlServer::~BradFlightSqlServer() = default; @@ -119,12 +119,9 @@ arrow::Result> transformed_query_result = TransformQueryResult(query_result); } - { - std::scoped_lock guard(query_data_mutex_); - query_data_.insert(query_ticket, transformed_query_result); - } - ARROW_ASSIGN_OR_RAISE(auto statement, BradStatement::Create(transformed_query_result)); + query_data_.insert(query_ticket, statement); + ARROW_ASSIGN_OR_RAISE(auto schema, statement->GetSchema()); std::vector endpoints{ @@ -151,13 +148,19 @@ arrow::Result> const std::string transaction_id = pair.second; const std::string &query_ticket = transaction_id + ':' + autoincrement_id; - const auto query_result = query_data_.find(query_ticket); - std::shared_ptr statement; - ARROW_ASSIGN_OR_RAISE(statement, BradStatement::Create(query_result)); + std::shared_ptr result; + const bool found = query_data_.erase_fn(query_ticket, [&result](auto& qr) { + result = qr; + return true; + }); + + if (!found) { + return arrow::Status::Invalid("Invalid ticket."); + } std::shared_ptr reader; - ARROW_ASSIGN_OR_RAISE(reader, BradStatementBatchReader::Create(statement)); + ARROW_ASSIGN_OR_RAISE(reader, BradStatementBatchReader::Create(result)); return std::make_unique(reader); } diff --git a/cpp/server/brad_server_simple.h b/cpp/server/brad_server_simple.h index dea03bc8..cbffe9e2 100644 --- a/cpp/server/brad_server_simple.h +++ b/cpp/server/brad_server_simple.h @@ -9,6 +9,7 @@ #include #include +#include "brad_statement.h" #include #include "libcuckoo/cuckoohash_map.hh" @@ -47,10 +48,10 @@ class BradFlightSqlServer : public arrow::flight::sql::FlightSqlServerBase { const arrow::flight::ServerCallContext &context, const arrow::flight::sql::StatementQueryTicket &command) override; + private: std::function(std::string)> handle_query_; - libcuckoo::cuckoohash_map>> query_data_; - std::mutex query_data_mutex_; + libcuckoo::cuckoohash_map> query_data_; std::atomic autoincrement_id_; }; diff --git a/cpp/server/brad_statement.cc b/cpp/server/brad_statement.cc index 40e18479..3d66c6ef 100644 --- a/cpp/server/brad_statement.cc +++ b/cpp/server/brad_statement.cc @@ -27,33 +27,81 @@ arrow::Result> BradStatement::Create( arrow::Result> BradStatement::Create( std::vector> query_result) { std::shared_ptr result( - new BradStatement(query_result)); + std::make_shared(query_result)); return result; } -BradStatement::BradStatement(std::vector> query_result) { - query_result_ = query_result; -} +BradStatement::BradStatement(std::vector> query_result) : + query_result_(std::move(query_result)) {} BradStatement::~BradStatement() { } -arrow::Result> BradStatement::GetSchema() const { +arrow::Result> BradStatement::GetSchema() { + if (schema_) { + return schema_; + } + std::vector> fields; const std::vector &row = query_result_[0]; + int counter = 0; for (const auto &field : row) { std::string field_type = field.type().name(); if (field_type == "i") { - fields.push_back(arrow::field("INT FIELD", arrow::int8())); + fields.push_back(arrow::field("INT FIELD " + std::to_string(++counter), arrow::int8())); } else if (field_type == "f") { - fields.push_back(arrow::field("FLOAT FIELD", arrow::float16())); + fields.push_back(arrow::field("FLOAT FIELD " + std::to_string(++counter), arrow::float32())); } else { - fields.push_back(arrow::field("STRING FIELD", arrow::utf8())); + fields.push_back(arrow::field("STRING FIELD " + std::to_string(++counter), arrow::utf8())); + } + } + + schema_ = arrow::schema(fields); + return schema_; +} + +arrow::Result> BradStatement::FetchResult() { + std::shared_ptr schema = GetSchema().ValueOrDie(); + + const int num_rows = query_result_.size(); + + std::vector> columns; + for (int field_ix = 0; field_ix < schema->num_fields(); ++field_ix) { + const auto &field = schema->fields()[field_ix]; + if (field->type() == arrow::int8()) { + arrow::Int8Builder int8builder; + int8_t values_raw[num_rows]; + for (int row_ix = 0; row_ix < num_rows; ++row_ix) { + values_raw[row_ix] = std::any_cast(query_result_[row_ix][field_ix]); + } + ARROW_RETURN_NOT_OK(int8builder.AppendValues(values_raw, num_rows)); + + std::shared_ptr values; + ARROW_ASSIGN_OR_RAISE(values, int8builder.Finish()); + + columns.push_back(values); + } else if (field->type() == arrow::float32()) { + arrow::FloatBuilder floatbuilder; + float values_raw[num_rows]; + for (int row_ix = 0; row_ix < num_rows; ++row_ix) { + values_raw[row_ix] = std::any_cast(query_result_[row_ix][field_ix]); + } + ARROW_RETURN_NOT_OK(floatbuilder.AppendValues(values_raw, num_rows)); + + std::shared_ptr values; + ARROW_ASSIGN_OR_RAISE(values, floatbuilder.Finish()); + + columns.push_back(values); + } else if (field->type() == arrow::utf8()) { } } - return arrow::schema(fields); + std::shared_ptr record_batch = + arrow::RecordBatch::Make(schema, + num_rows, + columns); + return record_batch; } std::string* BradStatement::GetBradStmt() const { return stmt_; } diff --git a/cpp/server/brad_statement.h b/cpp/server/brad_statement.h index 45b4a1e0..5c62dfea 100644 --- a/cpp/server/brad_statement.h +++ b/cpp/server/brad_statement.h @@ -38,13 +38,17 @@ class BradStatement { /// \brief Creates an Arrow Schema based on the results of this statement. /// \return The resulting Schema. - arrow::Result> GetSchema() const; + arrow::Result> GetSchema(); + + arrow::Result> FetchResult(); std::string* GetBradStmt() const; + private: std::vector> query_result_; - private: + std::shared_ptr schema_; + std::string* stmt_; BradStatement(std::string* stmt) : stmt_(stmt) {} diff --git a/cpp/server/brad_statement_batch_reader.cc b/cpp/server/brad_statement_batch_reader.cc index 9627ccda..48c9d5f2 100644 --- a/cpp/server/brad_statement_batch_reader.cc +++ b/cpp/server/brad_statement_batch_reader.cc @@ -13,7 +13,8 @@ BradStatementBatchReader::BradStatementBatchReader( std::shared_ptr statement, std::shared_ptr schema) : statement_(std::move(statement)), - schema_(std::move(schema)) {} + schema_(std::move(schema)), + already_executed_(false) {} arrow::Result> BradStatementBatchReader::Create( @@ -42,7 +43,7 @@ arrow::Status BradStatementBatchReader::ReadNext(std::shared_ptrFetchResult()); + ARROW_ASSIGN_OR_RAISE(*out, statement_->FetchResult()); already_executed_ = true; return arrow::Status::OK(); } diff --git a/cpp/third_party/CMakeLists.txt b/cpp/third_party/CMakeLists.txt index 002e0b50..8ba1d026 100644 --- a/cpp/third_party/CMakeLists.txt +++ b/cpp/third_party/CMakeLists.txt @@ -12,4 +12,10 @@ FetchContent_Declare( GIT_TAG v2.2.2 ) -FetchContent_MakeAvailable(pybind11 gflags) +FetchContent_Declare( + libcuckoo + GIT_REPOSITORY https://github.com/efficient/libcuckoo.git + GIT_TAG 784d0f5d147b9a73f897ae55f6c3712d9a91b058 +) + +FetchContent_MakeAvailable(pybind11 gflags libcuckoo) diff --git a/src/brad/front_end/front_end.py b/src/brad/front_end/front_end.py index 5af11477..178ea9e6 100644 --- a/src/brad/front_end/front_end.py +++ b/src/brad/front_end/front_end.py @@ -100,7 +100,7 @@ def __init__( else: self._flight_sql_server = None - self._main_thread_loop: Optional[AbstractEventLoop] = None + self._main_thread_loop: Optional[asyncio.AbstractEventLoop] = None self._fe_index = fe_index self._config = config From dcc081a576eb1cd786a0a8a5c192a5e44f17e298 Mon Sep 17 00:00:00 2001 From: Sophie Zhang Date: Fri, 19 Apr 2024 16:36:38 -0400 Subject: [PATCH 17/19] Fix mypy type check --- src/brad/front_end/front_end.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/brad/front_end/front_end.py b/src/brad/front_end/front_end.py index 178ea9e6..f460e2bf 100644 --- a/src/brad/front_end/front_end.py +++ b/src/brad/front_end/front_end.py @@ -199,6 +199,7 @@ def __init__( def _handle_query_from_flight_sql(self, query: str) -> RowList: assert self._flight_sql_server_session_id is not None + assert self._main_thread_loop is not None future = asyncio.run_coroutine_threadsafe( self._run_query_impl(self._flight_sql_server_session_id, query, {}), From b54818f6d629f0393e2a316dff1457d3b7f742f8 Mon Sep 17 00:00:00 2001 From: Sophie Zhang Date: Sat, 20 Apr 2024 18:30:17 -0400 Subject: [PATCH 18/19] Address PR comments, add in Arrow conversion for string types --- cpp/server/brad_server_simple.h | 4 ---- cpp/server/brad_statement.cc | 37 ++++++++++++++++++++++----------- cpp/server/brad_statement.h | 9 ++------ 3 files changed, 27 insertions(+), 23 deletions(-) diff --git a/cpp/server/brad_server_simple.h b/cpp/server/brad_server_simple.h index cbffe9e2..48056a04 100644 --- a/cpp/server/brad_server_simple.h +++ b/cpp/server/brad_server_simple.h @@ -3,10 +3,6 @@ #include #include #include -#include -#include -#include -#include #include #include "brad_statement.h" diff --git a/cpp/server/brad_statement.cc b/cpp/server/brad_statement.cc index 3d66c6ef..e9ce1588 100644 --- a/cpp/server/brad_statement.cc +++ b/cpp/server/brad_statement.cc @@ -37,23 +37,26 @@ BradStatement::BradStatement(std::vector> query_result) : BradStatement::~BradStatement() { } -arrow::Result> BradStatement::GetSchema() { +arrow::Result> BradStatement::GetSchema() const { if (schema_) { return schema_; } std::vector> fields; - const std::vector &row = query_result_[0]; - - int counter = 0; - for (const auto &field : row) { - std::string field_type = field.type().name(); - if (field_type == "i") { - fields.push_back(arrow::field("INT FIELD " + std::to_string(++counter), arrow::int8())); - } else if (field_type == "f") { - fields.push_back(arrow::field("FLOAT FIELD " + std::to_string(++counter), arrow::float32())); - } else { - fields.push_back(arrow::field("STRING FIELD " + std::to_string(++counter), arrow::utf8())); + + if (query_result_.size() > 0) { + const std::vector &row = query_result_[0]; + + int counter = 0; + for (const auto &field : row) { + std::string field_type = field.type().name(); + if (field_type == "i") { + fields.push_back(arrow::field("INT FIELD " + std::to_string(++counter), arrow::int8())); + } else if (field_type == "f") { + fields.push_back(arrow::field("FLOAT FIELD " + std::to_string(++counter), arrow::float32())); + } else { + fields.push_back(arrow::field("STRING FIELD " + std::to_string(++counter), arrow::utf8())); + } } } @@ -67,6 +70,8 @@ arrow::Result> BradStatement::FetchResult() const int num_rows = query_result_.size(); std::vector> columns; + columns.reserve(schema->num_fields()); + for (int field_ix = 0; field_ix < schema->num_fields(); ++field_ix) { const auto &field = schema->fields()[field_ix]; if (field->type() == arrow::int8()) { @@ -94,6 +99,14 @@ arrow::Result> BradStatement::FetchResult() columns.push_back(values); } else if (field->type() == arrow::utf8()) { + arrow::StringBuilder stringbuilder; + for (int row_ix = 0; row_ix < num_rows; ++row_ix) { + const std::string* str = std::any_cast(&(query_result_[row_ix][field_ix])); + ARROW_RETURN_NOT_OK(stringbuilder.Append(str->data(), str->size())); + } + + std::shared_ptr values; + ARROW_ASSIGN_OR_RAISE(values, stringbuilder.Finish()); } } diff --git a/cpp/server/brad_statement.h b/cpp/server/brad_statement.h index 5c62dfea..6f13bc70 100644 --- a/cpp/server/brad_statement.h +++ b/cpp/server/brad_statement.h @@ -7,11 +7,6 @@ #include #include -#include - -namespace py = pybind11; -using namespace pybind11::literals; - namespace brad { /// \brief Create an object ColumnMetadata using the column type and @@ -38,7 +33,7 @@ class BradStatement { /// \brief Creates an Arrow Schema based on the results of this statement. /// \return The resulting Schema. - arrow::Result> GetSchema(); + arrow::Result> GetSchema() const; arrow::Result> FetchResult(); @@ -47,7 +42,7 @@ class BradStatement { private: std::vector> query_result_; - std::shared_ptr schema_; + mutable std::shared_ptr schema_; std::string* stmt_; From cc909f946a11c88b52a51c4e2d11ae398c49ae61 Mon Sep 17 00:00:00 2001 From: Sophie Zhang Date: Sun, 21 Apr 2024 08:59:39 -0400 Subject: [PATCH 19/19] Add cpp library imports --- cpp/server/brad_server_simple.h | 3 +++ cpp/server/brad_statement.h | 1 + 2 files changed, 4 insertions(+) diff --git a/cpp/server/brad_server_simple.h b/cpp/server/brad_server_simple.h index 48056a04..d2e0c186 100644 --- a/cpp/server/brad_server_simple.h +++ b/cpp/server/brad_server_simple.h @@ -1,8 +1,11 @@ #pragma once +#include #include +#include #include #include +#include #include #include "brad_statement.h" diff --git a/cpp/server/brad_statement.h b/cpp/server/brad_statement.h index 6f13bc70..b3dba2cc 100644 --- a/cpp/server/brad_statement.h +++ b/cpp/server/brad_statement.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #include