Skip to content

Commit

Permalink
Address PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Sophie Zhang committed Apr 19, 2024
1 parent 7d85b67 commit 57a4961
Show file tree
Hide file tree
Showing 8 changed files with 90 additions and 35 deletions.
8 changes: 0 additions & 8 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,6 @@ find_package(Boost REQUIRED)

add_subdirectory(third_party)

# Concurrent hash table
FetchContent_Declare(
libcuckoo
GIT_REPOSITORY https://github.com/efficient/libcuckoo.git
GIT_TAG 784d0f5d147b9a73f897ae55f6c3712d9a91b058
)
FetchContent_MakeAvailable(libcuckoo)

add_library(sqlite_server_lib OBJECT
sqlite_server/sqlite_server.cc
sqlite_server/sqlite_sql_info.cc
Expand Down
23 changes: 13 additions & 10 deletions cpp/server/brad_server_simple.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ std::vector<std::vector<std::any>> TransformQueryResult(
return transformed_query_result;
}

BradFlightSqlServer::BradFlightSqlServer() = default;
BradFlightSqlServer::BradFlightSqlServer() : autoincrement_id_(0ULL) {}

BradFlightSqlServer::~BradFlightSqlServer() = default;

Expand Down Expand Up @@ -119,12 +119,9 @@ arrow::Result<std::unique_ptr<FlightInfo>>
transformed_query_result = TransformQueryResult(query_result);
}

{
std::scoped_lock guard(query_data_mutex_);
query_data_.insert(query_ticket, transformed_query_result);
}

ARROW_ASSIGN_OR_RAISE(auto statement, BradStatement::Create(transformed_query_result));
query_data_.insert(query_ticket, statement);

ARROW_ASSIGN_OR_RAISE(auto schema, statement->GetSchema());

std::vector<FlightEndpoint> endpoints{
Expand All @@ -151,13 +148,19 @@ arrow::Result<std::unique_ptr<FlightDataStream>>
const std::string transaction_id = pair.second;

const std::string &query_ticket = transaction_id + ':' + autoincrement_id;
const auto query_result = query_data_.find(query_ticket);

std::shared_ptr<BradStatement> statement;
ARROW_ASSIGN_OR_RAISE(statement, BradStatement::Create(query_result));
std::shared_ptr<BradStatement> result;
const bool found = query_data_.erase_fn(query_ticket, [&result](auto& qr) {
result = qr;
return true;
});

if (!found) {
return arrow::Status::Invalid("Invalid ticket.");
}

std::shared_ptr<BradStatementBatchReader> reader;
ARROW_ASSIGN_OR_RAISE(reader, BradStatementBatchReader::Create(statement));
ARROW_ASSIGN_OR_RAISE(reader, BradStatementBatchReader::Create(result));

return std::make_unique<RecordBatchStream>(reader);
}
Expand Down
5 changes: 3 additions & 2 deletions cpp/server/brad_server_simple.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <mutex>

#include <arrow/flight/sql/server.h>
#include "brad_statement.h"
#include <arrow/result.h>

#include "libcuckoo/cuckoohash_map.hh"
Expand Down Expand Up @@ -47,10 +48,10 @@ class BradFlightSqlServer : public arrow::flight::sql::FlightSqlServerBase {
const arrow::flight::ServerCallContext &context,
const arrow::flight::sql::StatementQueryTicket &command) override;

private:
std::function<std::vector<py::tuple>(std::string)> handle_query_;

libcuckoo::cuckoohash_map<std::string, std::vector<std::vector<std::any>>> query_data_;
std::mutex query_data_mutex_;
libcuckoo::cuckoohash_map<std::string, std::shared_ptr<BradStatement>> query_data_;

std::atomic<uint64_t> autoincrement_id_;
};
Expand Down
66 changes: 57 additions & 9 deletions cpp/server/brad_statement.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,33 +27,81 @@ arrow::Result<std::shared_ptr<BradStatement>> BradStatement::Create(
arrow::Result<std::shared_ptr<BradStatement>> BradStatement::Create(
std::vector<std::vector<std::any>> query_result) {
std::shared_ptr<BradStatement> result(
new BradStatement(query_result));
std::make_shared<BradStatement>(query_result));
return result;
}

BradStatement::BradStatement(std::vector<std::vector<std::any>> query_result) {
query_result_ = query_result;
}
BradStatement::BradStatement(std::vector<std::vector<std::any>> query_result) :
query_result_(std::move(query_result)) {}

BradStatement::~BradStatement() {
}

arrow::Result<std::shared_ptr<arrow::Schema>> BradStatement::GetSchema() const {
arrow::Result<std::shared_ptr<arrow::Schema>> BradStatement::GetSchema() {
if (schema_) {
return schema_;
}

std::vector<std::shared_ptr<arrow::Field>> fields;
const std::vector<std::any> &row = query_result_[0];

int counter = 0;
for (const auto &field : row) {
std::string field_type = field.type().name();
if (field_type == "i") {
fields.push_back(arrow::field("INT FIELD", arrow::int8()));
fields.push_back(arrow::field("INT FIELD " + std::to_string(++counter), arrow::int8()));
} else if (field_type == "f") {
fields.push_back(arrow::field("FLOAT FIELD", arrow::float16()));
fields.push_back(arrow::field("FLOAT FIELD " + std::to_string(++counter), arrow::float32()));
} else {
fields.push_back(arrow::field("STRING FIELD", arrow::utf8()));
fields.push_back(arrow::field("STRING FIELD " + std::to_string(++counter), arrow::utf8()));
}
}

schema_ = arrow::schema(fields);
return schema_;
}

arrow::Result<std::shared_ptr<arrow::RecordBatch>> BradStatement::FetchResult() {
std::shared_ptr<arrow::Schema> schema = GetSchema().ValueOrDie();

const int num_rows = query_result_.size();

std::vector<std::shared_ptr<arrow::Array>> columns;
for (int field_ix = 0; field_ix < schema->num_fields(); ++field_ix) {
const auto &field = schema->fields()[field_ix];
if (field->type() == arrow::int8()) {
arrow::Int8Builder int8builder;
int8_t values_raw[num_rows];
for (int row_ix = 0; row_ix < num_rows; ++row_ix) {
values_raw[row_ix] = std::any_cast<int>(query_result_[row_ix][field_ix]);
}
ARROW_RETURN_NOT_OK(int8builder.AppendValues(values_raw, num_rows));

std::shared_ptr<arrow::Array> values;
ARROW_ASSIGN_OR_RAISE(values, int8builder.Finish());

columns.push_back(values);
} else if (field->type() == arrow::float32()) {
arrow::FloatBuilder floatbuilder;
float values_raw[num_rows];
for (int row_ix = 0; row_ix < num_rows; ++row_ix) {
values_raw[row_ix] = std::any_cast<float>(query_result_[row_ix][field_ix]);
}
ARROW_RETURN_NOT_OK(floatbuilder.AppendValues(values_raw, num_rows));

std::shared_ptr<arrow::Array> values;
ARROW_ASSIGN_OR_RAISE(values, floatbuilder.Finish());

columns.push_back(values);
} else if (field->type() == arrow::utf8()) {
}
}

return arrow::schema(fields);
std::shared_ptr<arrow::RecordBatch> record_batch =
arrow::RecordBatch::Make(schema,
num_rows,
columns);
return record_batch;
}

std::string* BradStatement::GetBradStmt() const { return stmt_; }
Expand Down
8 changes: 6 additions & 2 deletions cpp/server/brad_statement.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,17 @@ class BradStatement {

/// \brief Creates an Arrow Schema based on the results of this statement.
/// \return The resulting Schema.
arrow::Result<std::shared_ptr<arrow::Schema>> GetSchema() const;
arrow::Result<std::shared_ptr<arrow::Schema>> GetSchema();

arrow::Result<std::shared_ptr<arrow::RecordBatch>> FetchResult();

std::string* GetBradStmt() const;

private:
std::vector<std::vector<std::any>> query_result_;

private:
std::shared_ptr<arrow::Schema> schema_;

std::string* stmt_;

BradStatement(std::string* stmt) : stmt_(stmt) {}
Expand Down
5 changes: 3 additions & 2 deletions cpp/server/brad_statement_batch_reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ BradStatementBatchReader::BradStatementBatchReader(
std::shared_ptr<BradStatement> statement,
std::shared_ptr<arrow::Schema> schema)
: statement_(std::move(statement)),
schema_(std::move(schema)) {}
schema_(std::move(schema)),
already_executed_(false) {}

arrow::Result<std::shared_ptr<BradStatementBatchReader>>
BradStatementBatchReader::Create(
Expand Down Expand Up @@ -42,7 +43,7 @@ arrow::Status BradStatementBatchReader::ReadNext(std::shared_ptr<arrow::RecordBa
return arrow::Status::OK();
}

// ARROW_ASSIGN_OR_RAISE(*out, statement_->FetchResult());
ARROW_ASSIGN_OR_RAISE(*out, statement_->FetchResult());
already_executed_ = true;
return arrow::Status::OK();
}
Expand Down
8 changes: 7 additions & 1 deletion cpp/third_party/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,10 @@ FetchContent_Declare(
GIT_TAG v2.2.2
)

FetchContent_MakeAvailable(pybind11 gflags)
FetchContent_Declare(
libcuckoo
GIT_REPOSITORY https://github.com/efficient/libcuckoo.git
GIT_TAG 784d0f5d147b9a73f897ae55f6c3712d9a91b058
)

FetchContent_MakeAvailable(pybind11 gflags libcuckoo)
2 changes: 1 addition & 1 deletion src/brad/front_end/front_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def __init__(
else:
self._flight_sql_server = None

self._main_thread_loop: Optional[AbstractEventLoop] = None
self._main_thread_loop: Optional[asyncio.AbstractEventLoop] = None

self._fe_index = fe_index
self._config = config
Expand Down

0 comments on commit 57a4961

Please sign in to comment.