diff --git a/cpp/pybind/brad_server.cc b/cpp/pybind/brad_server.cc index 881b7664..be006996 100644 --- a/cpp/pybind/brad_server.cc +++ b/cpp/pybind/brad_server.cc @@ -13,5 +13,10 @@ PYBIND11_MODULE(pybind_brad_server, m) { brad_server .def(py::init<>()) - .def("create", &brad::BradFlightSqlServer::Create); + .def("create", &brad::BradFlightSqlServer::Create) + .def("init", &brad::BradFlightSqlServer::InitWrapper) + .def("serve", + &brad::BradFlightSqlServer::ServeWrapper, + py::call_guard()) + .def("shutdown", &brad::BradFlightSqlServer::ShutdownWrapper); } diff --git a/cpp/server/brad_server_simple.cc b/cpp/server/brad_server_simple.cc index a2d69902..19321e76 100644 --- a/cpp/server/brad_server_simple.cc +++ b/cpp/server/brad_server_simple.cc @@ -51,12 +51,26 @@ BradFlightSqlServer::~BradFlightSqlServer() = default; std::shared_ptr BradFlightSqlServer::Create() { // std::shared_ptr result(new BradFlightSqlServer()); - std::shared_ptr result = - std::make_shared(); - for (const auto &id_to_result : GetSqlInfoResultMap()) { - result->RegisterSqlInfo(id_to_result.first, id_to_result.second); - } - return result; + std::shared_ptr result = + std::make_shared(); + for (const auto &id_to_result : GetSqlInfoResultMap()) { + result->RegisterSqlInfo(id_to_result.first, id_to_result.second); + } + return result; +} + +void BradFlightSqlServer::InitWrapper(const std::string &host, int port) { + auto location = arrow::flight::Location::ForGrpcTcp(host, port).ValueOrDie(); + arrow::flight::FlightServerOptions options(location); + this->Init(options); +} + +void BradFlightSqlServer::ServeWrapper() { + this->Serve(); +} + +void BradFlightSqlServer::ShutdownWrapper() { + this->Shutdown(nullptr); } arrow::Result> diff --git a/cpp/server/brad_server_simple.h b/cpp/server/brad_server_simple.h index 70c64893..f6db4cbf 100644 --- a/cpp/server/brad_server_simple.h +++ b/cpp/server/brad_server_simple.h @@ -17,6 +17,12 @@ class BradFlightSqlServer : public arrow::flight::sql::FlightSqlServerBase { static std::shared_ptr Create(); + void InitWrapper(const std::string &host, int port); + + void ServeWrapper(); + + void ShutdownWrapper(); + arrow::Result> GetFlightInfoStatement( const arrow::flight::ServerCallContext &context, diff --git a/src/brad/front_end/flight_sql_server.py b/src/brad/front_end/flight_sql_server.py new file mode 100644 index 00000000..eb736e20 --- /dev/null +++ b/src/brad/front_end/flight_sql_server.py @@ -0,0 +1,26 @@ +import logging +import threading + +# pylint: disable-next=import-error,no-name-in-module,unused-import +import brad.native.pybind_brad_server as brad_server + +logger = logging.getLogger(__name__) + + +class BradFlightSqlServer: + def __init__(self, host: str, port: int) -> None: + self._flight_sql_server = brad_server.BradFlightSqlServer() + self._flight_sql_server.init(host, port) + self._thread = threading.Thread(name="BradFlightSqlServer", target=self._serve) + + def start(self) -> None: + self._thread.start() + + def stop(self) -> None: + logger.info("BRAD FlightSQL server stopping...") + self._flight_sql_server.shutdown() + self._thread.join() + logger.info("BRAD FlightSQL server stopped.") + + def _serve(self) -> None: + self._flight_sql_server.serve() diff --git a/src/brad/front_end/front_end.py b/src/brad/front_end/front_end.py index e8f74a45..001f5371 100644 --- a/src/brad/front_end/front_end.py +++ b/src/brad/front_end/front_end.py @@ -87,10 +87,11 @@ def __init__( output_queue: mp.Queue, ): if BradFrontEnd.native_server_is_supported(): - # pylint: disable-next=import-error,no-name-in-module - import brad.native.pybind_brad_server as brad_server + from brad.front_end.flight_sql_server import BradFlightSqlServer - self._flight_sql_server = brad_server.BradFlightSqlServer.create() + self._flight_sql_server: Optional[BradFlightSqlServer] = ( + BradFlightSqlServer(host="0.0.0.0", port=31337) + ) else: self._flight_sql_server = None @@ -191,6 +192,10 @@ def __init__( async def serve_forever(self): await self._run_setup() + + # Start FlightSQL server + self._flight_sql_server.start() + try: grpc_server = grpc.aio.server() brad_grpc.add_BradServicer_to_server(BradGrpc(self), grpc_server) @@ -281,6 +286,11 @@ async def _set_up_router(self) -> None: async def _run_teardown(self): logger.debug("Starting BRAD front end _run_teardown()") + + # Shutdown FlightSQL server + if self._flight_sql_server: + self._flight_sql_server.stop() + await self._sessions.end_all_sessions() # Important for unblocking our message reader thread.