From 615d8e746b67a42dae225969cd20845082202196 Mon Sep 17 00:00:00 2001 From: Sophie Zhang Date: Sat, 27 Apr 2024 19:31:12 -0400 Subject: [PATCH 1/6] Use record batch in statements to avoid type conversions between Arrow and cpp --- cpp/server/brad_server_simple.cc | 65 ++++++++++++++++++++--- cpp/server/brad_statement.cc | 89 ++++---------------------------- cpp/server/brad_statement.h | 8 +-- 3 files changed, 74 insertions(+), 88 deletions(-) diff --git a/cpp/server/brad_server_simple.cc b/cpp/server/brad_server_simple.cc index 6c4260bc..43bb809d 100644 --- a/cpp/server/brad_server_simple.cc +++ b/cpp/server/brad_server_simple.cc @@ -7,6 +7,7 @@ #include #include +#include #include #include "brad_sql_info.h" #include "brad_statement.h" @@ -69,6 +70,60 @@ std::vector> TransformQueryResult( return transformed_query_result; } +arrow::Result> +ResultToRecordBatch(std::vector query_result, std::shared_ptr schema) { + // TODO: Handle edge case with empty vector + const int num_rows = query_result.size(); + + const auto &row = query_result[0]; + const int num_columns = row.size(); + std::vector> columns; + columns.reserve(num_columns); + + // TODO: Use schema fields instead of inferring from result + for (int field_ix = 0; field_ix < num_columns; ++field_ix) { + if (py::isinstance(row[field_ix])) { + arrow::Int64Builder int64builder; + int64_t values_raw[num_rows]; + for (int row_ix = 0; row_ix < num_rows; ++row_ix) { + values_raw[row_ix] = py::cast(query_result[row_ix][field_ix]); + } + ARROW_RETURN_NOT_OK(int64builder.AppendValues(values_raw, num_rows)); + + std::shared_ptr values; + ARROW_ASSIGN_OR_RAISE(values, int64builder.Finish()); + columns.push_back(values); + + } else if (py::isinstance(row[field_ix])) { + arrow::FloatBuilder floatbuilder; + float values_raw[num_rows]; + for (int row_ix = 0; row_ix < num_rows; ++row_ix) { + values_raw[row_ix] = py::cast(query_result[row_ix][field_ix]); + } + ARROW_RETURN_NOT_OK(floatbuilder.AppendValues(values_raw, num_rows)); + + std::shared_ptr values; + ARROW_ASSIGN_OR_RAISE(values, floatbuilder.Finish()); + columns.push_back(values); + + } else { + arrow::StringBuilder stringbuilder; + for (int row_ix = 0; row_ix < num_rows; ++row_ix) { + const std::string str = py::cast(query_result[row_ix][field_ix]); + ARROW_RETURN_NOT_OK(stringbuilder.Append(str.data(), str.size())); + } + + std::shared_ptr values; + ARROW_ASSIGN_OR_RAISE(values, stringbuilder.Finish()); + } + } + + std::shared_ptr result_record_batch = + arrow::RecordBatch::Make(schema, num_rows, columns); + + return result_record_batch; +} + BradFlightSqlServer::BradFlightSqlServer() : autoincrement_id_(0ULL) {} BradFlightSqlServer::~BradFlightSqlServer() = default; @@ -125,25 +180,23 @@ arrow::Result> EncodeTransactionQuery(query_ticket)); std::shared_ptr result_schema; - std::vector> transformed_query_result; + std::shared_ptr result_record_batch; { py::gil_scoped_acquire guard; auto result = handle_query_(query); result_schema = ArrowSchemaFromBradSchema(result.second); - transformed_query_result = TransformQueryResult(result.first); + result_record_batch = ResultToRecordBatch(result.first, result_schema).ValueOrDie(); } - ARROW_ASSIGN_OR_RAISE(auto statement, BradStatement::Create(transformed_query_result)); + ARROW_ASSIGN_OR_RAISE(auto statement, BradStatement::Create(result_record_batch, result_schema)); query_data_.insert(query_ticket, statement); - ARROW_ASSIGN_OR_RAISE(auto schema, statement->GetSchema()); - std::vector endpoints{ FlightEndpoint{std::move(ticket), {}, std::nullopt, ""}}; const bool ordered = false; - ARROW_ASSIGN_OR_RAISE(auto result, FlightInfo::Make(*schema, + ARROW_ASSIGN_OR_RAISE(auto result, FlightInfo::Make(*result_schema, descriptor, endpoints, -1, diff --git a/cpp/server/brad_statement.cc b/cpp/server/brad_statement.cc index e9ce1588..0db4a786 100644 --- a/cpp/server/brad_statement.cc +++ b/cpp/server/brad_statement.cc @@ -25,96 +25,27 @@ arrow::Result> BradStatement::Create( } arrow::Result> BradStatement::Create( - std::vector> query_result) { - std::shared_ptr result( - std::make_shared(query_result)); - return result; + std::shared_ptr result_record_batch, + std::shared_ptr schema) { + std::shared_ptr result( + std::make_shared(result_record_batch, schema)); + return result; } -BradStatement::BradStatement(std::vector> query_result) : - query_result_(std::move(query_result)) {} +BradStatement::BradStatement(std::shared_ptr result_record_batch, + std::shared_ptr schema) : + result_record_batch_(std::move(result_record_batch)), + schema_(std::move(schema)) {} BradStatement::~BradStatement() { } arrow::Result> BradStatement::GetSchema() const { - if (schema_) { - return schema_; - } - - std::vector> fields; - - if (query_result_.size() > 0) { - const std::vector &row = query_result_[0]; - - int counter = 0; - for (const auto &field : row) { - std::string field_type = field.type().name(); - if (field_type == "i") { - fields.push_back(arrow::field("INT FIELD " + std::to_string(++counter), arrow::int8())); - } else if (field_type == "f") { - fields.push_back(arrow::field("FLOAT FIELD " + std::to_string(++counter), arrow::float32())); - } else { - fields.push_back(arrow::field("STRING FIELD " + std::to_string(++counter), arrow::utf8())); - } - } - } - - schema_ = arrow::schema(fields); return schema_; } arrow::Result> BradStatement::FetchResult() { - std::shared_ptr schema = GetSchema().ValueOrDie(); - - const int num_rows = query_result_.size(); - - std::vector> columns; - columns.reserve(schema->num_fields()); - - for (int field_ix = 0; field_ix < schema->num_fields(); ++field_ix) { - const auto &field = schema->fields()[field_ix]; - if (field->type() == arrow::int8()) { - arrow::Int8Builder int8builder; - int8_t values_raw[num_rows]; - for (int row_ix = 0; row_ix < num_rows; ++row_ix) { - values_raw[row_ix] = std::any_cast(query_result_[row_ix][field_ix]); - } - ARROW_RETURN_NOT_OK(int8builder.AppendValues(values_raw, num_rows)); - - std::shared_ptr values; - ARROW_ASSIGN_OR_RAISE(values, int8builder.Finish()); - - columns.push_back(values); - } else if (field->type() == arrow::float32()) { - arrow::FloatBuilder floatbuilder; - float values_raw[num_rows]; - for (int row_ix = 0; row_ix < num_rows; ++row_ix) { - values_raw[row_ix] = std::any_cast(query_result_[row_ix][field_ix]); - } - ARROW_RETURN_NOT_OK(floatbuilder.AppendValues(values_raw, num_rows)); - - std::shared_ptr values; - ARROW_ASSIGN_OR_RAISE(values, floatbuilder.Finish()); - - columns.push_back(values); - } else if (field->type() == arrow::utf8()) { - arrow::StringBuilder stringbuilder; - for (int row_ix = 0; row_ix < num_rows; ++row_ix) { - const std::string* str = std::any_cast(&(query_result_[row_ix][field_ix])); - ARROW_RETURN_NOT_OK(stringbuilder.Append(str->data(), str->size())); - } - - std::shared_ptr values; - ARROW_ASSIGN_OR_RAISE(values, stringbuilder.Finish()); - } - } - - std::shared_ptr record_batch = - arrow::RecordBatch::Make(schema, - num_rows, - columns); - return record_batch; + return result_record_batch_; } std::string* BradStatement::GetBradStmt() const { return stmt_; } diff --git a/cpp/server/brad_statement.h b/cpp/server/brad_statement.h index b3dba2cc..6d296c16 100644 --- a/cpp/server/brad_statement.h +++ b/cpp/server/brad_statement.h @@ -26,9 +26,11 @@ class BradStatement { const std::string& sql); static arrow::Result> Create( - const std::vector>); + std::shared_ptr result_record_batch, + std::shared_ptr schema); - BradStatement(std::vector>); + BradStatement(std::shared_ptr, + std::shared_ptr); ~BradStatement(); @@ -41,7 +43,7 @@ class BradStatement { std::string* GetBradStmt() const; private: - std::vector> query_result_; + std::shared_ptr result_record_batch_; mutable std::shared_ptr schema_; From 4aca7ea1eec75296f333bbb953f83b1864a22b23 Mon Sep 17 00:00:00 2001 From: Sophie Zhang Date: Sat, 27 Apr 2024 21:38:31 -0400 Subject: [PATCH 2/6] Build record batch from schema fields --- cpp/server/brad_server_simple.cc | 33 ++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/cpp/server/brad_server_simple.cc b/cpp/server/brad_server_simple.cc index 43bb809d..51dda75f 100644 --- a/cpp/server/brad_server_simple.cc +++ b/cpp/server/brad_server_simple.cc @@ -7,6 +7,8 @@ #include #include +#include + #include #include #include "brad_sql_info.h" @@ -72,36 +74,31 @@ std::vector> TransformQueryResult( arrow::Result> ResultToRecordBatch(std::vector query_result, std::shared_ptr schema) { - // TODO: Handle edge case with empty vector const int num_rows = query_result.size(); - const auto &row = query_result[0]; - const int num_columns = row.size(); + const int num_columns = schema->num_fields(); std::vector> columns; columns.reserve(num_columns); - // TODO: Use schema fields instead of inferring from result for (int field_ix = 0; field_ix < num_columns; ++field_ix) { - if (py::isinstance(row[field_ix])) { + const auto &field_type = schema->field(field_ix)->type(); + if (field_type->Equals(arrow::int64())) { arrow::Int64Builder int64builder; - int64_t values_raw[num_rows]; for (int row_ix = 0; row_ix < num_rows; ++row_ix) { - values_raw[row_ix] = py::cast(query_result[row_ix][field_ix]); + const int64_t val = py::cast(query_result[row_ix][field_ix]); + // TODO: How do we check for null values in ints or floats? + ARROW_RETURN_NOT_OK(int64builder.Append(val)); } - ARROW_RETURN_NOT_OK(int64builder.AppendValues(values_raw, num_rows)); - std::shared_ptr values; ARROW_ASSIGN_OR_RAISE(values, int64builder.Finish()); columns.push_back(values); - } else if (py::isinstance(row[field_ix])) { + } else if (field_type->Equals(arrow::float32())) { arrow::FloatBuilder floatbuilder; - float values_raw[num_rows]; for (int row_ix = 0; row_ix < num_rows; ++row_ix) { - values_raw[row_ix] = py::cast(query_result[row_ix][field_ix]); + const float val = py::cast(query_result[row_ix][field_ix]); + ARROW_RETURN_NOT_OK(floatbuilder.Append(val)); } - ARROW_RETURN_NOT_OK(floatbuilder.AppendValues(values_raw, num_rows)); - std::shared_ptr values; ARROW_ASSIGN_OR_RAISE(values, floatbuilder.Finish()); columns.push_back(values); @@ -110,11 +107,15 @@ ResultToRecordBatch(std::vector query_result, std::shared_ptr(query_result[row_ix][field_ix]); - ARROW_RETURN_NOT_OK(stringbuilder.Append(str.data(), str.size())); + if (str.empty()) { + ARROW_RETURN_NOT_OK(stringbuilder.Append(str.data(), str.size())); + } else { + ARROW_RETURN_NOT_OK(stringbuilder.AppendNull()); + } } - std::shared_ptr values; ARROW_ASSIGN_OR_RAISE(values, stringbuilder.Finish()); + columns.push_back(values); } } From 9c70ec65600680b389fcd7faa2a7e34cfd945c00 Mon Sep 17 00:00:00 2001 From: Sophie Zhang Date: Sat, 27 Apr 2024 21:44:45 -0400 Subject: [PATCH 3/6] Remove TransformQueryResult --- cpp/server/brad_server_simple.cc | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/cpp/server/brad_server_simple.cc b/cpp/server/brad_server_simple.cc index 51dda75f..aac21999 100644 --- a/cpp/server/brad_server_simple.cc +++ b/cpp/server/brad_server_simple.cc @@ -7,8 +7,6 @@ #include #include -#include - #include #include #include "brad_sql_info.h" @@ -53,25 +51,6 @@ arrow::Result> DecodeTransactionQuery( return std::make_pair(std::move(autoincrement_id), std::move(transaction_id)); } -std::vector> TransformQueryResult( - std::vector query_result) { - std::vector> transformed_query_result; - for (const auto &row : query_result) { - std::vector transformed_row{}; - for (const auto &field : row) { - if (py::isinstance(field)) { - transformed_row.push_back(std::make_any(py::cast(field))); - } else if (py::isinstance(field)) { - transformed_row.push_back(std::make_any(py::cast(field))); - } else { - transformed_row.push_back(std::make_any(py::cast(field))); - } - } - transformed_query_result.push_back(transformed_row); - } - return transformed_query_result; -} - arrow::Result> ResultToRecordBatch(std::vector query_result, std::shared_ptr schema) { const int num_rows = query_result.size(); From a7ac46593520e8c0788a2f035ac490b76ba1b2a9 Mon Sep 17 00:00:00 2001 From: Sophie Zhang Date: Sat, 27 Apr 2024 22:38:31 -0400 Subject: [PATCH 4/6] Expand field type checks in ResultToRecordBatch --- cpp/server/brad_server_simple.cc | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/cpp/server/brad_server_simple.cc b/cpp/server/brad_server_simple.cc index aac21999..af9a8ef3 100644 --- a/cpp/server/brad_server_simple.cc +++ b/cpp/server/brad_server_simple.cc @@ -64,7 +64,7 @@ ResultToRecordBatch(std::vector query_result, std::shared_ptrEquals(arrow::int64())) { arrow::Int64Builder int64builder; for (int row_ix = 0; row_ix < num_rows; ++row_ix) { - const int64_t val = py::cast(query_result[row_ix][field_ix]); + const int64_t val = py::cast(query_result[row_ix][field_ix]); // TODO: How do we check for null values in ints or floats? ARROW_RETURN_NOT_OK(int64builder.Append(val)); } @@ -72,7 +72,9 @@ ResultToRecordBatch(std::vector query_result, std::shared_ptrEquals(arrow::float32())) { + } else if (field_type->Equals(arrow::float32()) || + // TODO: Should not hardcode precision and scale values + field_type->Equals(arrow::decimal(/*precision=*/10, /*scale=*/2))) { arrow::FloatBuilder floatbuilder; for (int row_ix = 0; row_ix < num_rows; ++row_ix) { const float val = py::cast(query_result[row_ix][field_ix]); @@ -82,7 +84,7 @@ ResultToRecordBatch(std::vector query_result, std::shared_ptrEquals(arrow::utf8())) { arrow::StringBuilder stringbuilder; for (int row_ix = 0; row_ix < num_rows; ++row_ix) { const std::string str = py::cast(query_result[row_ix][field_ix]); @@ -95,6 +97,17 @@ ResultToRecordBatch(std::vector query_result, std::shared_ptr values; ARROW_ASSIGN_OR_RAISE(values, stringbuilder.Finish()); columns.push_back(values); + + } else if (field_type->Equals(arrow::date64())) { + arrow::Date64Builder datebuilder; + for (int row_ix = 0; row_ix < num_rows; ++row_ix) { + const int64_t val = py::cast(query_result[row_ix][field_ix]); + ARROW_RETURN_NOT_OK(datebuilder.Append(val)); + } + std::shared_ptr values; + ARROW_ASSIGN_OR_RAISE(values, datebuilder.Finish()); + columns.push_back(values); + } } From ea7c72d162eb6f6bc594a3500a936dd497d4ae2f Mon Sep 17 00:00:00 2001 From: Sophie Zhang Date: Sun, 28 Apr 2024 20:48:40 -0400 Subject: [PATCH 5/6] Separate Decimal field type handler, check for Null fields, handle nulltype data --- cpp/server/brad_server_simple.cc | 74 ++++++++++++++++++++++++-------- cpp/server/brad_server_simple.h | 1 + 2 files changed, 57 insertions(+), 18 deletions(-) diff --git a/cpp/server/brad_server_simple.cc b/cpp/server/brad_server_simple.cc index af9a8ef3..6f79469e 100644 --- a/cpp/server/brad_server_simple.cc +++ b/cpp/server/brad_server_simple.cc @@ -51,11 +51,12 @@ arrow::Result> DecodeTransactionQuery( return std::make_pair(std::move(autoincrement_id), std::move(transaction_id)); } -arrow::Result> -ResultToRecordBatch(std::vector query_result, std::shared_ptr schema) { - const int num_rows = query_result.size(); +arrow::Result> ResultToRecordBatch( + const std::vector &query_result, + const std::shared_ptr &schema) { + const size_t num_rows = query_result.size(); - const int num_columns = schema->num_fields(); + const size_t num_columns = schema->num_fields(); std::vector> columns; columns.reserve(num_columns); @@ -64,32 +65,56 @@ ResultToRecordBatch(std::vector query_result, std::shared_ptrEquals(arrow::int64())) { arrow::Int64Builder int64builder; for (int row_ix = 0; row_ix < num_rows; ++row_ix) { - const int64_t val = py::cast(query_result[row_ix][field_ix]); - // TODO: How do we check for null values in ints or floats? - ARROW_RETURN_NOT_OK(int64builder.Append(val)); + const std::optional val = + py::cast>(query_result[row_ix][field_ix]); + if (val) { + ARROW_RETURN_NOT_OK(int64builder.Append(*val)); + } else { + ARROW_RETURN_NOT_OK(int64builder.AppendNull()); + } } std::shared_ptr values; ARROW_ASSIGN_OR_RAISE(values, int64builder.Finish()); columns.push_back(values); - } else if (field_type->Equals(arrow::float32()) || - // TODO: Should not hardcode precision and scale values - field_type->Equals(arrow::decimal(/*precision=*/10, /*scale=*/2))) { + } else if (field_type->Equals(arrow::float32())) { arrow::FloatBuilder floatbuilder; for (int row_ix = 0; row_ix < num_rows; ++row_ix) { - const float val = py::cast(query_result[row_ix][field_ix]); - ARROW_RETURN_NOT_OK(floatbuilder.Append(val)); + const std::optional val = + py::cast>(query_result[row_ix][field_ix]); + if (val) { + ARROW_RETURN_NOT_OK(floatbuilder.Append(*val)); + } else { + ARROW_RETURN_NOT_OK(floatbuilder.AppendNull()); + } } std::shared_ptr values; ARROW_ASSIGN_OR_RAISE(values, floatbuilder.Finish()); columns.push_back(values); + } else if (field_type->Equals(arrow::decimal(/*precision=*/10, /*scale=*/2))) { + arrow::Decimal128Builder decimalbuilder(arrow::decimal(/*precision=*/10, /*scale=*/2)); + for (int row_ix = 0; row_ix < num_rows; ++row_ix) { + const std::optional val = + py::cast>(query_result[row_ix][field_ix]); + if (val) { + ARROW_RETURN_NOT_OK( + decimalbuilder.Append(arrow::Decimal128::FromString(*val).ValueOrDie())); + } else { + ARROW_RETURN_NOT_OK(decimalbuilder.AppendNull()); + } + } + std::shared_ptr values; + ARROW_ASSIGN_OR_RAISE(values, decimalbuilder.Finish()); + columns.push_back(values); + } else if (field_type->Equals(arrow::utf8())) { arrow::StringBuilder stringbuilder; for (int row_ix = 0; row_ix < num_rows; ++row_ix) { - const std::string str = py::cast(query_result[row_ix][field_ix]); - if (str.empty()) { - ARROW_RETURN_NOT_OK(stringbuilder.Append(str.data(), str.size())); + const std::optional str = + py::cast>(query_result[row_ix][field_ix]); + if (str) { + ARROW_RETURN_NOT_OK(stringbuilder.Append(str->data(), str->size())); } else { ARROW_RETURN_NOT_OK(stringbuilder.AppendNull()); } @@ -101,13 +126,26 @@ ResultToRecordBatch(std::vector query_result, std::shared_ptrEquals(arrow::date64())) { arrow::Date64Builder datebuilder; for (int row_ix = 0; row_ix < num_rows; ++row_ix) { - const int64_t val = py::cast(query_result[row_ix][field_ix]); - ARROW_RETURN_NOT_OK(datebuilder.Append(val)); + const std::optional val = + py::cast>(query_result[row_ix][field_ix]); + if (val) { + ARROW_RETURN_NOT_OK(datebuilder.Append(*val)); + } else { + ARROW_RETURN_NOT_OK(datebuilder.AppendNull()); + } } std::shared_ptr values; ARROW_ASSIGN_OR_RAISE(values, datebuilder.Finish()); columns.push_back(values); + } else if (field_type->Equals(arrow::null())) { + arrow::NullBuilder nullbuilder; + for (int row_ix = 0; row_ix < num_rows; ++row_ix) { + ARROW_RETURN_NOT_OK(nullbuilder.AppendNull()); + } + std::shared_ptr values; + ARROW_ASSIGN_OR_RAISE(values, nullbuilder.Finish()); + columns.push_back(values); } } @@ -179,7 +217,7 @@ arrow::Result> py::gil_scoped_acquire guard; auto result = handle_query_(query); result_schema = ArrowSchemaFromBradSchema(result.second); - result_record_batch = ResultToRecordBatch(result.first, result_schema).ValueOrDie(); + result_record_batch = ResultToRecordBatch(std::move(result.first), result_schema).ValueOrDie(); } ARROW_ASSIGN_OR_RAISE(auto statement, BradStatement::Create(result_record_batch, result_schema)); diff --git a/cpp/server/brad_server_simple.h b/cpp/server/brad_server_simple.h index 484ea216..ee6eaf21 100644 --- a/cpp/server/brad_server_simple.h +++ b/cpp/server/brad_server_simple.h @@ -15,6 +15,7 @@ #include "libcuckoo/cuckoohash_map.hh" #include +#include namespace brad { From 6dbd8c22d15e36e6917b9b45ea40a140764f1653 Mon Sep 17 00:00:00 2001 From: Sophie Zhang Date: Tue, 30 Apr 2024 21:37:56 -0400 Subject: [PATCH 6/6] Address comments regarding std::move --- cpp/server/brad_server_simple.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/server/brad_server_simple.cc b/cpp/server/brad_server_simple.cc index 6f79469e..5cc7594d 100644 --- a/cpp/server/brad_server_simple.cc +++ b/cpp/server/brad_server_simple.cc @@ -217,10 +217,10 @@ arrow::Result> py::gil_scoped_acquire guard; auto result = handle_query_(query); result_schema = ArrowSchemaFromBradSchema(result.second); - result_record_batch = ResultToRecordBatch(std::move(result.first), result_schema).ValueOrDie(); + result_record_batch = ResultToRecordBatch(result.first, result_schema).ValueOrDie(); } - ARROW_ASSIGN_OR_RAISE(auto statement, BradStatement::Create(result_record_batch, result_schema)); + ARROW_ASSIGN_OR_RAISE(auto statement, BradStatement::Create(std::move(result_record_batch), result_schema)); query_data_.insert(query_ticket, statement); std::vector endpoints{