diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 4a68f5a7..cb82fa75 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -16,14 +16,6 @@ find_package(Boost REQUIRED) add_subdirectory(third_party) -# Concurrent hash table -FetchContent_Declare( - libcuckoo - GIT_REPOSITORY https://github.com/efficient/libcuckoo.git - GIT_TAG 784d0f5d147b9a73f897ae55f6c3712d9a91b058 -) -FetchContent_MakeAvailable(libcuckoo) - add_library(sqlite_server_lib OBJECT sqlite_server/sqlite_server.cc sqlite_server/sqlite_sql_info.cc diff --git a/cpp/server/brad_server_simple.cc b/cpp/server/brad_server_simple.cc index 3a5b7b96..cb090030 100644 --- a/cpp/server/brad_server_simple.cc +++ b/cpp/server/brad_server_simple.cc @@ -65,7 +65,7 @@ std::vector> TransformQueryResult( return transformed_query_result; } -BradFlightSqlServer::BradFlightSqlServer() = default; +BradFlightSqlServer::BradFlightSqlServer() : autoincrement_id_(0ULL) {} BradFlightSqlServer::~BradFlightSqlServer() = default; @@ -119,12 +119,9 @@ arrow::Result> transformed_query_result = TransformQueryResult(query_result); } - { - std::scoped_lock guard(query_data_mutex_); - query_data_.insert(query_ticket, transformed_query_result); - } - ARROW_ASSIGN_OR_RAISE(auto statement, BradStatement::Create(transformed_query_result)); + query_data_.insert(query_ticket, statement); + ARROW_ASSIGN_OR_RAISE(auto schema, statement->GetSchema()); std::vector endpoints{ @@ -151,13 +148,19 @@ arrow::Result> const std::string transaction_id = pair.second; const std::string &query_ticket = transaction_id + ':' + autoincrement_id; - const auto query_result = query_data_.find(query_ticket); - std::shared_ptr statement; - ARROW_ASSIGN_OR_RAISE(statement, BradStatement::Create(query_result)); + std::shared_ptr result; + const bool found = query_data_.erase_fn(query_ticket, [&result](auto& qr) { + result = qr; + return true; + }); + + if (!found) { + return arrow::Status::Invalid("Invalid ticket."); + } std::shared_ptr reader; - ARROW_ASSIGN_OR_RAISE(reader, BradStatementBatchReader::Create(statement)); + ARROW_ASSIGN_OR_RAISE(reader, BradStatementBatchReader::Create(result)); return std::make_unique(reader); } diff --git a/cpp/server/brad_server_simple.h b/cpp/server/brad_server_simple.h index dea03bc8..cbffe9e2 100644 --- a/cpp/server/brad_server_simple.h +++ b/cpp/server/brad_server_simple.h @@ -9,6 +9,7 @@ #include #include +#include "brad_statement.h" #include #include "libcuckoo/cuckoohash_map.hh" @@ -47,10 +48,10 @@ class BradFlightSqlServer : public arrow::flight::sql::FlightSqlServerBase { const arrow::flight::ServerCallContext &context, const arrow::flight::sql::StatementQueryTicket &command) override; + private: std::function(std::string)> handle_query_; - libcuckoo::cuckoohash_map>> query_data_; - std::mutex query_data_mutex_; + libcuckoo::cuckoohash_map> query_data_; std::atomic autoincrement_id_; }; diff --git a/cpp/server/brad_statement.cc b/cpp/server/brad_statement.cc index 40e18479..3d66c6ef 100644 --- a/cpp/server/brad_statement.cc +++ b/cpp/server/brad_statement.cc @@ -27,33 +27,81 @@ arrow::Result> BradStatement::Create( arrow::Result> BradStatement::Create( std::vector> query_result) { std::shared_ptr result( - new BradStatement(query_result)); + std::make_shared(query_result)); return result; } -BradStatement::BradStatement(std::vector> query_result) { - query_result_ = query_result; -} +BradStatement::BradStatement(std::vector> query_result) : + query_result_(std::move(query_result)) {} BradStatement::~BradStatement() { } -arrow::Result> BradStatement::GetSchema() const { +arrow::Result> BradStatement::GetSchema() { + if (schema_) { + return schema_; + } + std::vector> fields; const std::vector &row = query_result_[0]; + int counter = 0; for (const auto &field : row) { std::string field_type = field.type().name(); if (field_type == "i") { - fields.push_back(arrow::field("INT FIELD", arrow::int8())); + fields.push_back(arrow::field("INT FIELD " + std::to_string(++counter), arrow::int8())); } else if (field_type == "f") { - fields.push_back(arrow::field("FLOAT FIELD", arrow::float16())); + fields.push_back(arrow::field("FLOAT FIELD " + std::to_string(++counter), arrow::float32())); } else { - fields.push_back(arrow::field("STRING FIELD", arrow::utf8())); + fields.push_back(arrow::field("STRING FIELD " + std::to_string(++counter), arrow::utf8())); + } + } + + schema_ = arrow::schema(fields); + return schema_; +} + +arrow::Result> BradStatement::FetchResult() { + std::shared_ptr schema = GetSchema().ValueOrDie(); + + const int num_rows = query_result_.size(); + + std::vector> columns; + for (int field_ix = 0; field_ix < schema->num_fields(); ++field_ix) { + const auto &field = schema->fields()[field_ix]; + if (field->type() == arrow::int8()) { + arrow::Int8Builder int8builder; + int8_t values_raw[num_rows]; + for (int row_ix = 0; row_ix < num_rows; ++row_ix) { + values_raw[row_ix] = std::any_cast(query_result_[row_ix][field_ix]); + } + ARROW_RETURN_NOT_OK(int8builder.AppendValues(values_raw, num_rows)); + + std::shared_ptr values; + ARROW_ASSIGN_OR_RAISE(values, int8builder.Finish()); + + columns.push_back(values); + } else if (field->type() == arrow::float32()) { + arrow::FloatBuilder floatbuilder; + float values_raw[num_rows]; + for (int row_ix = 0; row_ix < num_rows; ++row_ix) { + values_raw[row_ix] = std::any_cast(query_result_[row_ix][field_ix]); + } + ARROW_RETURN_NOT_OK(floatbuilder.AppendValues(values_raw, num_rows)); + + std::shared_ptr values; + ARROW_ASSIGN_OR_RAISE(values, floatbuilder.Finish()); + + columns.push_back(values); + } else if (field->type() == arrow::utf8()) { } } - return arrow::schema(fields); + std::shared_ptr record_batch = + arrow::RecordBatch::Make(schema, + num_rows, + columns); + return record_batch; } std::string* BradStatement::GetBradStmt() const { return stmt_; } diff --git a/cpp/server/brad_statement.h b/cpp/server/brad_statement.h index 45b4a1e0..5c62dfea 100644 --- a/cpp/server/brad_statement.h +++ b/cpp/server/brad_statement.h @@ -38,13 +38,17 @@ class BradStatement { /// \brief Creates an Arrow Schema based on the results of this statement. /// \return The resulting Schema. - arrow::Result> GetSchema() const; + arrow::Result> GetSchema(); + + arrow::Result> FetchResult(); std::string* GetBradStmt() const; + private: std::vector> query_result_; - private: + std::shared_ptr schema_; + std::string* stmt_; BradStatement(std::string* stmt) : stmt_(stmt) {} diff --git a/cpp/server/brad_statement_batch_reader.cc b/cpp/server/brad_statement_batch_reader.cc index 9627ccda..48c9d5f2 100644 --- a/cpp/server/brad_statement_batch_reader.cc +++ b/cpp/server/brad_statement_batch_reader.cc @@ -13,7 +13,8 @@ BradStatementBatchReader::BradStatementBatchReader( std::shared_ptr statement, std::shared_ptr schema) : statement_(std::move(statement)), - schema_(std::move(schema)) {} + schema_(std::move(schema)), + already_executed_(false) {} arrow::Result> BradStatementBatchReader::Create( @@ -42,7 +43,7 @@ arrow::Status BradStatementBatchReader::ReadNext(std::shared_ptrFetchResult()); + ARROW_ASSIGN_OR_RAISE(*out, statement_->FetchResult()); already_executed_ = true; return arrow::Status::OK(); } diff --git a/cpp/third_party/CMakeLists.txt b/cpp/third_party/CMakeLists.txt index 002e0b50..8ba1d026 100644 --- a/cpp/third_party/CMakeLists.txt +++ b/cpp/third_party/CMakeLists.txt @@ -12,4 +12,10 @@ FetchContent_Declare( GIT_TAG v2.2.2 ) -FetchContent_MakeAvailable(pybind11 gflags) +FetchContent_Declare( + libcuckoo + GIT_REPOSITORY https://github.com/efficient/libcuckoo.git + GIT_TAG 784d0f5d147b9a73f897ae55f6c3712d9a91b058 +) + +FetchContent_MakeAvailable(pybind11 gflags libcuckoo) diff --git a/src/brad/front_end/front_end.py b/src/brad/front_end/front_end.py index 5af11477..178ea9e6 100644 --- a/src/brad/front_end/front_end.py +++ b/src/brad/front_end/front_end.py @@ -100,7 +100,7 @@ def __init__( else: self._flight_sql_server = None - self._main_thread_loop: Optional[AbstractEventLoop] = None + self._main_thread_loop: Optional[asyncio.AbstractEventLoop] = None self._fe_index = fe_index self._config = config