From b4a54a9c8e2acd1de415c39a421a5359f53ab63b Mon Sep 17 00:00:00 2001 From: Sophie Zhang <88999452+sopzha@users.noreply.github.com> Date: Tue, 30 Apr 2024 21:39:50 -0400 Subject: [PATCH] Create RecordBatch in BradStatement from query result and schema exposed from underlying connections (#502) Co-authored-by: Sophie Zhang --- cpp/server/brad_server_simple.cc | 124 ++++++++++++++++++++++++++----- cpp/server/brad_server_simple.h | 1 + cpp/server/brad_statement.cc | 89 +++------------------- cpp/server/brad_statement.h | 8 +- 4 files changed, 120 insertions(+), 102 deletions(-) diff --git a/cpp/server/brad_server_simple.cc b/cpp/server/brad_server_simple.cc index 6c4260bc..5cc7594d 100644 --- a/cpp/server/brad_server_simple.cc +++ b/cpp/server/brad_server_simple.cc @@ -7,6 +7,7 @@ #include #include +#include #include #include "brad_sql_info.h" #include "brad_statement.h" @@ -50,23 +51,108 @@ arrow::Result> DecodeTransactionQuery( return std::make_pair(std::move(autoincrement_id), std::move(transaction_id)); } -std::vector> TransformQueryResult( - std::vector query_result) { - std::vector> transformed_query_result; - for (const auto &row : query_result) { - std::vector transformed_row{}; - for (const auto &field : row) { - if (py::isinstance(field)) { - transformed_row.push_back(std::make_any(py::cast(field))); - } else if (py::isinstance(field)) { - transformed_row.push_back(std::make_any(py::cast(field))); - } else { - transformed_row.push_back(std::make_any(py::cast(field))); +arrow::Result> ResultToRecordBatch( + const std::vector &query_result, + const std::shared_ptr &schema) { + const size_t num_rows = query_result.size(); + + const size_t num_columns = schema->num_fields(); + std::vector> columns; + columns.reserve(num_columns); + + for (int field_ix = 0; field_ix < num_columns; ++field_ix) { + const auto &field_type = schema->field(field_ix)->type(); + if (field_type->Equals(arrow::int64())) { + arrow::Int64Builder int64builder; + for (int row_ix = 0; row_ix < num_rows; ++row_ix) { + const std::optional val = + py::cast>(query_result[row_ix][field_ix]); + if (val) { + ARROW_RETURN_NOT_OK(int64builder.Append(*val)); + } else { + ARROW_RETURN_NOT_OK(int64builder.AppendNull()); + } } + std::shared_ptr values; + ARROW_ASSIGN_OR_RAISE(values, int64builder.Finish()); + columns.push_back(values); + + } else if (field_type->Equals(arrow::float32())) { + arrow::FloatBuilder floatbuilder; + for (int row_ix = 0; row_ix < num_rows; ++row_ix) { + const std::optional val = + py::cast>(query_result[row_ix][field_ix]); + if (val) { + ARROW_RETURN_NOT_OK(floatbuilder.Append(*val)); + } else { + ARROW_RETURN_NOT_OK(floatbuilder.AppendNull()); + } + } + std::shared_ptr values; + ARROW_ASSIGN_OR_RAISE(values, floatbuilder.Finish()); + columns.push_back(values); + + } else if (field_type->Equals(arrow::decimal(/*precision=*/10, /*scale=*/2))) { + arrow::Decimal128Builder decimalbuilder(arrow::decimal(/*precision=*/10, /*scale=*/2)); + for (int row_ix = 0; row_ix < num_rows; ++row_ix) { + const std::optional val = + py::cast>(query_result[row_ix][field_ix]); + if (val) { + ARROW_RETURN_NOT_OK( + decimalbuilder.Append(arrow::Decimal128::FromString(*val).ValueOrDie())); + } else { + ARROW_RETURN_NOT_OK(decimalbuilder.AppendNull()); + } + } + std::shared_ptr values; + ARROW_ASSIGN_OR_RAISE(values, decimalbuilder.Finish()); + columns.push_back(values); + + } else if (field_type->Equals(arrow::utf8())) { + arrow::StringBuilder stringbuilder; + for (int row_ix = 0; row_ix < num_rows; ++row_ix) { + const std::optional str = + py::cast>(query_result[row_ix][field_ix]); + if (str) { + ARROW_RETURN_NOT_OK(stringbuilder.Append(str->data(), str->size())); + } else { + ARROW_RETURN_NOT_OK(stringbuilder.AppendNull()); + } + } + std::shared_ptr values; + ARROW_ASSIGN_OR_RAISE(values, stringbuilder.Finish()); + columns.push_back(values); + + } else if (field_type->Equals(arrow::date64())) { + arrow::Date64Builder datebuilder; + for (int row_ix = 0; row_ix < num_rows; ++row_ix) { + const std::optional val = + py::cast>(query_result[row_ix][field_ix]); + if (val) { + ARROW_RETURN_NOT_OK(datebuilder.Append(*val)); + } else { + ARROW_RETURN_NOT_OK(datebuilder.AppendNull()); + } + } + std::shared_ptr values; + ARROW_ASSIGN_OR_RAISE(values, datebuilder.Finish()); + columns.push_back(values); + + } else if (field_type->Equals(arrow::null())) { + arrow::NullBuilder nullbuilder; + for (int row_ix = 0; row_ix < num_rows; ++row_ix) { + ARROW_RETURN_NOT_OK(nullbuilder.AppendNull()); + } + std::shared_ptr values; + ARROW_ASSIGN_OR_RAISE(values, nullbuilder.Finish()); + columns.push_back(values); } - transformed_query_result.push_back(transformed_row); } - return transformed_query_result; + + std::shared_ptr result_record_batch = + arrow::RecordBatch::Make(schema, num_rows, columns); + + return result_record_batch; } BradFlightSqlServer::BradFlightSqlServer() : autoincrement_id_(0ULL) {} @@ -125,25 +211,23 @@ arrow::Result> EncodeTransactionQuery(query_ticket)); std::shared_ptr result_schema; - std::vector> transformed_query_result; + std::shared_ptr result_record_batch; { py::gil_scoped_acquire guard; auto result = handle_query_(query); result_schema = ArrowSchemaFromBradSchema(result.second); - transformed_query_result = TransformQueryResult(result.first); + result_record_batch = ResultToRecordBatch(result.first, result_schema).ValueOrDie(); } - ARROW_ASSIGN_OR_RAISE(auto statement, BradStatement::Create(transformed_query_result)); + ARROW_ASSIGN_OR_RAISE(auto statement, BradStatement::Create(std::move(result_record_batch), result_schema)); query_data_.insert(query_ticket, statement); - ARROW_ASSIGN_OR_RAISE(auto schema, statement->GetSchema()); - std::vector endpoints{ FlightEndpoint{std::move(ticket), {}, std::nullopt, ""}}; const bool ordered = false; - ARROW_ASSIGN_OR_RAISE(auto result, FlightInfo::Make(*schema, + ARROW_ASSIGN_OR_RAISE(auto result, FlightInfo::Make(*result_schema, descriptor, endpoints, -1, diff --git a/cpp/server/brad_server_simple.h b/cpp/server/brad_server_simple.h index 484ea216..ee6eaf21 100644 --- a/cpp/server/brad_server_simple.h +++ b/cpp/server/brad_server_simple.h @@ -15,6 +15,7 @@ #include "libcuckoo/cuckoohash_map.hh" #include +#include namespace brad { diff --git a/cpp/server/brad_statement.cc b/cpp/server/brad_statement.cc index e9ce1588..0db4a786 100644 --- a/cpp/server/brad_statement.cc +++ b/cpp/server/brad_statement.cc @@ -25,96 +25,27 @@ arrow::Result> BradStatement::Create( } arrow::Result> BradStatement::Create( - std::vector> query_result) { - std::shared_ptr result( - std::make_shared(query_result)); - return result; + std::shared_ptr result_record_batch, + std::shared_ptr schema) { + std::shared_ptr result( + std::make_shared(result_record_batch, schema)); + return result; } -BradStatement::BradStatement(std::vector> query_result) : - query_result_(std::move(query_result)) {} +BradStatement::BradStatement(std::shared_ptr result_record_batch, + std::shared_ptr schema) : + result_record_batch_(std::move(result_record_batch)), + schema_(std::move(schema)) {} BradStatement::~BradStatement() { } arrow::Result> BradStatement::GetSchema() const { - if (schema_) { - return schema_; - } - - std::vector> fields; - - if (query_result_.size() > 0) { - const std::vector &row = query_result_[0]; - - int counter = 0; - for (const auto &field : row) { - std::string field_type = field.type().name(); - if (field_type == "i") { - fields.push_back(arrow::field("INT FIELD " + std::to_string(++counter), arrow::int8())); - } else if (field_type == "f") { - fields.push_back(arrow::field("FLOAT FIELD " + std::to_string(++counter), arrow::float32())); - } else { - fields.push_back(arrow::field("STRING FIELD " + std::to_string(++counter), arrow::utf8())); - } - } - } - - schema_ = arrow::schema(fields); return schema_; } arrow::Result> BradStatement::FetchResult() { - std::shared_ptr schema = GetSchema().ValueOrDie(); - - const int num_rows = query_result_.size(); - - std::vector> columns; - columns.reserve(schema->num_fields()); - - for (int field_ix = 0; field_ix < schema->num_fields(); ++field_ix) { - const auto &field = schema->fields()[field_ix]; - if (field->type() == arrow::int8()) { - arrow::Int8Builder int8builder; - int8_t values_raw[num_rows]; - for (int row_ix = 0; row_ix < num_rows; ++row_ix) { - values_raw[row_ix] = std::any_cast(query_result_[row_ix][field_ix]); - } - ARROW_RETURN_NOT_OK(int8builder.AppendValues(values_raw, num_rows)); - - std::shared_ptr values; - ARROW_ASSIGN_OR_RAISE(values, int8builder.Finish()); - - columns.push_back(values); - } else if (field->type() == arrow::float32()) { - arrow::FloatBuilder floatbuilder; - float values_raw[num_rows]; - for (int row_ix = 0; row_ix < num_rows; ++row_ix) { - values_raw[row_ix] = std::any_cast(query_result_[row_ix][field_ix]); - } - ARROW_RETURN_NOT_OK(floatbuilder.AppendValues(values_raw, num_rows)); - - std::shared_ptr values; - ARROW_ASSIGN_OR_RAISE(values, floatbuilder.Finish()); - - columns.push_back(values); - } else if (field->type() == arrow::utf8()) { - arrow::StringBuilder stringbuilder; - for (int row_ix = 0; row_ix < num_rows; ++row_ix) { - const std::string* str = std::any_cast(&(query_result_[row_ix][field_ix])); - ARROW_RETURN_NOT_OK(stringbuilder.Append(str->data(), str->size())); - } - - std::shared_ptr values; - ARROW_ASSIGN_OR_RAISE(values, stringbuilder.Finish()); - } - } - - std::shared_ptr record_batch = - arrow::RecordBatch::Make(schema, - num_rows, - columns); - return record_batch; + return result_record_batch_; } std::string* BradStatement::GetBradStmt() const { return stmt_; } diff --git a/cpp/server/brad_statement.h b/cpp/server/brad_statement.h index b3dba2cc..6d296c16 100644 --- a/cpp/server/brad_statement.h +++ b/cpp/server/brad_statement.h @@ -26,9 +26,11 @@ class BradStatement { const std::string& sql); static arrow::Result> Create( - const std::vector>); + std::shared_ptr result_record_batch, + std::shared_ptr schema); - BradStatement(std::vector>); + BradStatement(std::shared_ptr, + std::shared_ptr); ~BradStatement(); @@ -41,7 +43,7 @@ class BradStatement { std::string* GetBradStmt() const; private: - std::vector> query_result_; + std::shared_ptr result_record_batch_; mutable std::shared_ptr schema_;