Skip to content

Commit

Permalink
Extract schema information from the underlying connections, add Arrow…
Browse files Browse the repository at this point in the history
… schema conversion utility (#497)

This PR modifies BRAD's connection code to extract schema information
(the column names and types) from the underlying database connections.
Each database has slightly different types, so we unify them in a "best
effort" way for now. Later on we will have better support for specific
SQL dialects.

**High level summary of the changes**
- Add `result_schema()` to BRAD cursors
- Define BRAD-specific `DataType`, `Field` and `Schema` classes
- Add a C++ helper function that converts a Python BRAD `Schema` into an
Arrow schema
- Various minor modifications to pass the Python schema to the Flight
SQL code
  • Loading branch information
geoffxy authored Apr 26, 2024
1 parent 255f8b1 commit 3b40183
Show file tree
Hide file tree
Showing 14 changed files with 420 additions and 28 deletions.
3 changes: 2 additions & 1 deletion cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ pybind11_add_module(pybind_brad_server pybind/brad_server.cc
server/brad_sql_info.cc
server/brad_statement_batch_reader.cc
server/brad_statement.cc
server/brad_tables_schema_batch_reader.cc)
server/brad_tables_schema_batch_reader.cc
server/python_utils.cc)

target_link_libraries(pybind_brad_server
PRIVATE Arrow::arrow_shared
Expand Down
27 changes: 21 additions & 6 deletions cpp/server/brad_server_simple.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
#include <sstream>
#include <unordered_map>
#include <utility>
#include <stdexcept>

#include <arrow/array/builder_binary.h>
#include "brad_sql_info.h"
#include "brad_statement.h"
#include "brad_statement_batch_reader.h"
#include "brad_tables_schema_batch_reader.h"
#include "python_utils.h"
#include <arrow/flight/sql/server.h>
#include <arrow/scalar.h>
#include <arrow/util/checked_cast.h>
Expand All @@ -22,6 +24,8 @@ using arrow::internal::checked_cast;
using namespace arrow::flight;
using namespace arrow::flight::sql;

namespace py = pybind11;

std::string GetQueryTicket(
const std::string &autoincrement_id,
const std::string &transaction_id) {
Expand Down Expand Up @@ -82,21 +86,30 @@ std::shared_ptr<BradFlightSqlServer>
void BradFlightSqlServer::InitWrapper(
const std::string &host,
int port,
std::function<std::vector<py::tuple>(std::string)> handle_query) {
PythonRunQueryFn handle_query) {
auto location = arrow::flight::Location::ForGrpcTcp(host, port).ValueOrDie();
arrow::flight::FlightServerOptions options(location);

handle_query_ = handle_query;

this->Init(options);
const auto status = this->Init(options);
if (!status.ok()) {
throw std::runtime_error(status.message());
}
}

// Runs the server by calling `Serve()` exactly once.
//
// Throws `std::runtime_error` (carrying the Arrow status message) if `Serve()`
// returns a non-OK status.
void BradFlightSqlServer::ServeWrapper() {
  const auto status = this->Serve();
  if (!status.ok()) {
    throw std::runtime_error(status.message());
  }
}

// Shuts the server down by calling `Shutdown()` exactly once, passing no
// deadline (`nullptr`).
//
// Throws `std::runtime_error` (carrying the Arrow status message) if
// `Shutdown()` returns a non-OK status.
void BradFlightSqlServer::ShutdownWrapper() {
  const auto status = this->Shutdown(nullptr);
  if (!status.ok()) {
    throw std::runtime_error(status.message());
  }
}

arrow::Result<std::unique_ptr<FlightInfo>>
Expand All @@ -111,12 +124,14 @@ arrow::Result<std::unique_ptr<FlightInfo>>
ARROW_ASSIGN_OR_RAISE(auto ticket,
EncodeTransactionQuery(query_ticket));

std::shared_ptr<arrow::Schema> result_schema;
std::vector<std::vector<std::any>> transformed_query_result;

{
py::gil_scoped_acquire guard;
std::vector<py::tuple> query_result = handle_query_(query);
transformed_query_result = TransformQueryResult(query_result);
auto result = handle_query_(query);
result_schema = ArrowSchemaFromBradSchema(result.second);
transformed_query_result = TransformQueryResult(result.first);
}

ARROW_ASSIGN_OR_RAISE(auto statement, BradStatement::Create(transformed_query_result));
Expand Down
14 changes: 9 additions & 5 deletions cpp/server/brad_server_simple.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <memory>
#include <string>
#include <vector>
#include <utility>

#include <arrow/flight/sql/server.h>
#include "brad_statement.h"
Expand All @@ -15,11 +16,14 @@

#include <pybind11/pybind11.h>

namespace py = pybind11;
using namespace pybind11::literals;

namespace brad {

// Signature of the Python callback that executes a SQL query. The callback
// receives the query text as a string and returns a pair holding
//   - first:  the results, as a vector of `pybind11::tuple`, and
//   - second: a Python schema object describing those results.
//
// NOTE: The GIL must be held when invoking this function.
using PythonRunQueryFn =
    std::function<std::pair<std::vector<pybind11::tuple>, pybind11::object>(
        std::string)>;

class BradFlightSqlServer : public arrow::flight::sql::FlightSqlServerBase {
public:
explicit BradFlightSqlServer();
Expand All @@ -30,7 +34,7 @@ class BradFlightSqlServer : public arrow::flight::sql::FlightSqlServerBase {

void InitWrapper(const std::string &host,
int port,
std::function<std::vector<py::tuple>(std::string)>);
PythonRunQueryFn handle_query);

void ServeWrapper();

Expand All @@ -48,7 +52,7 @@ class BradFlightSqlServer : public arrow::flight::sql::FlightSqlServerBase {
const arrow::flight::sql::StatementQueryTicket &command) override;

private:
std::function<std::vector<py::tuple>(std::string)> handle_query_;
PythonRunQueryFn handle_query_;

libcuckoo::cuckoohash_map<std::string, std::shared_ptr<BradStatement>> query_data_;

Expand Down
62 changes: 62 additions & 0 deletions cpp/server/python_utils.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#include "python_utils.h"

#include <vector>
#include <arrow/type.h>
#include <iostream>

namespace py = pybind11;

namespace {

// Maps a BRAD `DataType` enum member (passed as a Python object) to the Arrow
// data type used to represent it. The member is identified through its integer
// `value` attribute; any value not listed below (including 0) maps to
// `arrow::null()`.
//
// NOTE: The GIL must be held, since this reads an attribute of a Python object.
std::shared_ptr<arrow::DataType> ArrowDataTypeFromBradDataType(const pybind11::object& data_type) {
  // NOTE: If you change values here, make sure to change
  // `brad.connection.schema.DataType` as well.
  constexpr int64_t kInteger = 1;
  constexpr int64_t kFloat = 2;
  constexpr int64_t kDecimal = 3;
  constexpr int64_t kString = 4;
  constexpr int64_t kTimestamp = 5;

  const auto type_id = py::cast<int64_t>(data_type.attr("value"));

  if (type_id == kInteger) {
    return arrow::int64();
  }
  if (type_id == kFloat) {
    // NOTE(review): 8-byte float columns are also reported as `Float`; confirm
    // that a 32-bit Arrow float is intended for them.
    return arrow::float32();
  }
  if (type_id == kDecimal) {
    // Ideally the precision and scale would travel with the data type rather
    // than being hardcoded here.
    return arrow::decimal(/*precision=*/10, /*scale=*/2);
  }
  if (type_id == kString) {
    return arrow::utf8();
  }
  if (type_id == kTimestamp) {
    return arrow::date64();
  }

  // DataType.Unknown (0) and anything unrecognized.
  return arrow::null();
}

} // namespace

namespace brad {

// Builds an `arrow::Schema` from a Python BRAD schema object. The schema is
// iterated to obtain its fields; each field contributes its `name` and the
// Arrow equivalent of its `data_type`, in iteration order.
//
// NOTE: The GIL must be held, since this accesses Python objects.
std::shared_ptr<arrow::Schema> ArrowSchemaFromBradSchema(const pybind11::object& schema) {
  std::vector<std::shared_ptr<arrow::Field>> arrow_fields;
  // `num_fields` is only used to size the vector up front.
  arrow_fields.reserve(py::cast<size_t>(schema.attr("num_fields")));

  for (const auto& py_field : schema) {
    auto name = py::cast<std::string>(py_field.attr("name"));
    auto arrow_type = ArrowDataTypeFromBradDataType(py_field.attr("data_type"));
    arrow_fields.push_back(arrow::field(std::move(name), std::move(arrow_type)));
  }

  return arrow::schema(std::move(arrow_fields));
}

} // namespace brad
16 changes: 16 additions & 0 deletions cpp/server/python_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#pragma once

#include <memory>
#include <arrow/type.h>
#include <pybind11/pybind11.h>

namespace brad {

// Converts a `brad.connection.schema.Schema` Python object into an
// `arrow::Schema`. The passed in `schema` must be an instance of
// `brad.connection.schema.Schema`.
//
// The returned schema has one Arrow field per BRAD field, in the order the
// Python schema yields them when iterated. Each Arrow field carries the BRAD
// field's `name` and the Arrow counterpart of its `data_type`; data types
// without a known counterpart are represented with Arrow's null type.
//
// NOTE: The GIL must be held while running this function.
std::shared_ptr<arrow::Schema> ArrowSchemaFromBradSchema(const pybind11::object& schema);

} // namespace brad
6 changes: 6 additions & 0 deletions src/brad/connection/cursor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Tuple, Optional, List, Iterator, AsyncIterator, Iterable
from .schema import Schema


Row = Tuple[Any, ...]
Expand Down Expand Up @@ -48,6 +49,11 @@ def fetchone_sync(self) -> Optional[Row]:
def fetchall_sync(self) -> List[Row]:
raise NotImplementedError

def result_schema(self, results: Optional[List[Row]] = None) -> Schema:
    """
    Returns the schema (column names and types) of this cursor's results.

    `results` only needs to be passed in when running in stub mode, where it
    is needed for type deduction.

    This base implementation always raises `NotImplementedError`.
    """
    raise NotImplementedError

def __iter__(self) -> Iterator[Row]:
def do_iteration():
while True:
Expand Down
28 changes: 27 additions & 1 deletion src/brad/connection/odbc_cursor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import asyncio
import datetime
import decimal
from typing import Any, Optional, List, Iterable

from .cursor import Cursor, Row
from .schema import Schema, Field, DataType


class OdbcCursor(Cursor):
Expand Down Expand Up @@ -39,7 +42,30 @@ def fetchone_sync(self) -> Optional[Row]:
return self._impl.fetchone()

def fetchall_sync(self) -> List[Row]:
    # Delegates directly to the underlying ODBC cursor's `fetchall()`.
    return self._impl.fetchall()

def result_schema(self, results: Optional[List[Row]] = None) -> Schema:
    """
    Builds a BRAD `Schema` from the underlying ODBC cursor's `description`.

    Each description entry provides the column name (index 0) and the Python
    type reported for the column (index 1). That Python type is mapped to a
    BRAD `DataType`; types without a mapping become `DataType.Unknown`.

    Returns an empty schema when the cursor has no description, matching the
    other cursor implementations. `results` is not used here.
    """
    description = self._impl.description
    if description is None:
        # No result set is available, so there is nothing to iterate over.
        return Schema.empty()

    fields = []
    for column_metadata in description:
        column_name = column_metadata[0]
        odbc_type = column_metadata[1]
        if odbc_type is int or odbc_type is bool:
            # Booleans are represented as integers.
            brad_type = DataType.Integer
        elif odbc_type is str:
            brad_type = DataType.String
        elif odbc_type is float:
            brad_type = DataType.Float
        elif odbc_type is decimal.Decimal:
            brad_type = DataType.Decimal
        elif odbc_type is datetime.datetime:
            brad_type = DataType.Timestamp
        else:
            brad_type = DataType.Unknown
        fields.append(Field(name=column_name, data_type=brad_type))
    return Schema(fields)

def commit_sync(self) -> None:
self._impl.commit()
Expand Down
39 changes: 39 additions & 0 deletions src/brad/connection/psycopg_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Optional, List, Iterable

from .cursor import Cursor, Row
from .schema import Schema, Field, DataType


class PsycopgCursor(Cursor):
Expand Down Expand Up @@ -43,8 +44,46 @@ def fetchone_sync(self) -> Optional[Row]:
def fetchall_sync(self) -> List[Row]:
return self._impl.fetchall()

def result_schema(self, results: Optional[List[Row]] = None) -> Schema:
    """
    Builds a BRAD `Schema` from the underlying psycopg cursor's `description`.

    Each column's `type_code` is looked up in `_POSTGRESQL_OID_TO_BRAD_TYPE`;
    codes missing from that table become `DataType.Unknown`. Returns an empty
    schema when the cursor has no description. `results` is not used here.
    """
    description = self._impl.description
    if description is None:
        return Schema.empty()

    return Schema(
        [
            Field(
                name=column.name,
                data_type=_POSTGRESQL_OID_TO_BRAD_TYPE.get(
                    column.type_code, DataType.Unknown
                ),
            )
            for column in description
        ]
    )

def commit_sync(self) -> None:
self._conn.commit()

def rollback_sync(self) -> None:
self._conn.rollback()


# Maps PostgreSQL type OIDs to BRAD data types. OIDs that do not appear here
# are treated as `DataType.Unknown` by the lookup code.
#
# To list the types supported by the underlying database, use
# iter(self._impl.adapters.types).
_POSTGRESQL_OID_TO_BRAD_TYPE = {
    # Integer-like types (booleans and OIDs are represented as integers).
    16: DataType.Integer,  # bool
    20: DataType.Integer,  # int8
    21: DataType.Integer,  # int2
    23: DataType.Integer,  # int4
    26: DataType.Integer,  # oid
    # Floating point types.
    700: DataType.Float,  # float4
    701: DataType.Float,  # float8
    # Fixed precision types.
    1700: DataType.Decimal,  # numeric
    # String types.
    25: DataType.String,  # text
    1042: DataType.String,  # bpchar
    1043: DataType.String,  # varchar
    # Time-related types.
    1083: DataType.Timestamp,  # time
    1114: DataType.Timestamp,  # timestamp
    # N.B. We do not currently support date types.
}
24 changes: 24 additions & 0 deletions src/brad/connection/pyathena_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any, Iterable, Optional, List

from .cursor import Cursor, Row
from .schema import Schema, Field, DataType


class PyAthenaCursor(Cursor):
Expand Down Expand Up @@ -42,6 +43,29 @@ def fetchone_sync(self) -> Optional[Row]:
def fetchall_sync(self) -> List[Row]:
return self._impl.fetchall() # type: ignore

def result_schema(self, results: Optional[List[Row]] = None) -> Schema:
    """
    Builds a BRAD `Schema` from the underlying PyAthena cursor's `description`.

    Each description entry provides the column name (index 0) and the Athena
    type name (index 1). The type name is mapped to a BRAD `DataType` on a
    best-effort basis, mirroring the PostgreSQL mapping (all integer widths
    and booleans -> Integer, all float widths -> Float, all character types
    -> String). Type names without a mapping become `DataType.Unknown`.

    Returns an empty schema when the cursor has no description. `results` is
    not used here.
    """
    description = self._impl.description
    if description is None:
        return Schema.empty()

    fields = []
    for column_metadata in description:
        column_name = column_metadata[0]
        athena_type = column_metadata[1]
        if athena_type in ("tinyint", "smallint", "integer", "bigint", "boolean"):
            # E.g., `COUNT(*)` is reported as `bigint`.
            brad_type = DataType.Integer
        elif athena_type in ("char", "varchar", "string"):
            brad_type = DataType.String
        elif athena_type in ("float", "real", "double"):
            brad_type = DataType.Float
        elif athena_type == "timestamp":
            brad_type = DataType.Timestamp
        elif athena_type == "decimal":
            brad_type = DataType.Decimal
        else:
            brad_type = DataType.Unknown
        fields.append(Field(name=column_name, data_type=brad_type))
    return Schema(fields)

def commit_sync(self) -> None:
pass

Expand Down
Loading

0 comments on commit 3b40183

Please sign in to comment.