diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 8c3438ab..a55a61d6 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -1,5 +1,7 @@ cmake_minimum_required(VERSION 3.16) +option(BRAD_BUILD_EXPERIMENTAL OFF "Set to build the experimental code.") + set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC") @@ -16,6 +18,10 @@ find_package(Boost REQUIRED) add_subdirectory(third_party) +if(BRAD_BUILD_EXPERIMENTAL) + add_subdirectory(experimental) +endif() + add_library(sqlite_server_lib OBJECT sqlite_server/sqlite_server.cc sqlite_server/sqlite_sql_info.cc diff --git a/cpp/experimental/CMakeLists.txt b/cpp/experimental/CMakeLists.txt new file mode 100644 index 00000000..c2e70a5c --- /dev/null +++ b/cpp/experimental/CMakeLists.txt @@ -0,0 +1,2 @@ +add_executable(proxy_socket proxy_socket.cc) +target_link_libraries(proxy_socket PRIVATE gflags) diff --git a/cpp/experimental/proxy_socket.cc b/cpp/experimental/proxy_socket.cc new file mode 100644 index 00000000..76768498 --- /dev/null +++ b/cpp/experimental/proxy_socket.cc @@ -0,0 +1,276 @@ +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include + +DEFINE_int32(port, 31337, "Port that this server should listen on."); + +DEFINE_int32(proxy_to_port, 5439, "Port that this server should proxy its connection to."); +DEFINE_string(proxy_to_host, "", "The host that this server should proxy its connection to."); + +namespace { + +class Socket { + public: + static Socket Connect(const std::string& host, const uint16_t port) { + const int fd = socket(AF_INET, SOCK_STREAM, 0); + if (fd < 0) { + perror("Socket failed."); + throw std::runtime_error("Socket failed."); + } + + struct sockaddr_in serv_addr; + serv_addr.sin_family = AF_INET; + serv_addr.sin_port = htons(port); + + if(inet_pton(AF_INET, host.c_str(), &serv_addr.sin_addr) < 0) { + perror("Host conversion."); + throw std::runtime_error("Host conversion."); + } + + if (connect(fd, reinterpret_cast(&serv_addr), sizeof(serv_addr)) < 0) { + perror("Connect failed"); + throw std::runtime_error("Connect failed."); + } + + return Socket(fd); + } + + // No copying or copy assignment. + Socket(const Socket&) = delete; + Socket& operator=(const Socket&) = delete; + + ~Socket() { close(fd_); } + + int fd() const { return fd_; } + + private: + friend class ServerSocket; + explicit Socket(int fd) : fd_(fd) {} + + int fd_; +}; + +class ServerSocket { + public: + explicit ServerSocket(uint16_t port) : port_(port), fd_(-1) { + struct sockaddr_in address; + int opt = 1; + int addrlen = sizeof(address); + + // Creating socket file descriptor + fd_ = socket(AF_INET, SOCK_STREAM, 0); + if (fd_ == 0) { + perror("Socket failed"); + throw std::runtime_error("Socket failed"); + } + + if (setsockopt(fd_, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &opt, sizeof(opt))) { + perror("setsockopt"); + throw std::runtime_error("setsockopt"); + } + + address.sin_family = AF_INET; + address.sin_addr.s_addr = INADDR_ANY; + address.sin_port = htons(port); + + if (bind(fd_, reinterpret_cast(&address), sizeof(address)) < 0) { + perror("bind failed"); + throw std::runtime_error("bind failed"); + } + + if (listen(fd_, 1) < 0) { + perror("listen"); + throw std::runtime_error("listen failed"); + } + } + + Socket Accept() const { + struct sockaddr_in address; + socklen_t addrlen = sizeof(address); + const int new_fd = accept(fd_, reinterpret_cast(&address), &addrlen); + if (new_fd < 0) { + perror("Accept failed"); + throw std::runtime_error("Accept failed"); + } + return Socket(new_fd); + } + + ~ServerSocket() { close(fd_); } + + // No copying or copy assignment. + ServerSocket(const ServerSocket&) = delete; + ServerSocket& operator=(const ServerSocket&) = delete; + + int fd() const { return fd_; } + + private: + uint16_t port_; + int fd_; +}; + +class SentinelPipe { + public: + SentinelPipe() { + if (pipe(fd_) < 0) { + perror("Pipe failed."); + throw std::runtime_error("Pipe failed"); + } + } + + ~SentinelPipe() { + for (int i = 0; i < 2; ++i) { + if (fd_[i] > 0) { + close(fd_[i]); + fd_[i] = -1; + } + } + } + + SentinelPipe(const SentinelPipe&) = delete; + SentinelPipe& operator=(const SentinelPipe&) = delete; + + int read_fd() const { return fd_[0]; } + int write_fd() const { return fd_[1]; } + + private: + int fd_[2]; +}; + +class Buffer { + public: + Buffer(size_t size) : buf_(nullptr) { + buf_ = new uint8_t[size]; + } + + ~Buffer() { + if (buf_ == nullptr) return; + delete buf_; + buf_ = nullptr; + } + + uint8_t* buffer() const { return buf_; } + + private: + uint8_t* buf_; +}; + +std::function g_handle_signal; + +void signal_wrapper(int signal) { + if (!g_handle_signal) return; + g_handle_signal(signal); +} + +} // namespace + +int main(int argc, char* argv[]) { + gflags::SetUsageMessage("Proxies TCP connections."); + gflags::ParseCommandLineFlags(&argc, &argv, true); + + if (FLAGS_proxy_to_host.empty()) { + std::cerr << "ERROR: Must provide a value for --proxy-to-host" << std::endl; + return 1; + } + + // Workflow: + // - Start a socket listening for connections on `port` + // - Once we accept one connection, open a socket to the proxied-to host/port + // - Shuffle bytes to and from the two connections + // - Close the sockets on Ctrl-C or when there is an EOF + + ServerSocket server(FLAGS_port); + std::cerr << "Listening for a connection on port " << FLAGS_port << std::endl; + + const Socket to_client = server.Accept(); + std::cerr << "Accepted client connection." << std::endl; + + std::cerr << "Connecting to " << FLAGS_proxy_to_host << ":" << FLAGS_proxy_to_port << std::endl; + const Socket to_underlying = Socket::Connect(FLAGS_proxy_to_host, FLAGS_proxy_to_port); + std::cerr << "Connection succeeded." << std::endl; + + // Handle early exit (Ctrl+C or SIGTERM). + SentinelPipe sentinel; + g_handle_signal = [&sentinel](int signal) { + char null_char = '\0'; + write(sentinel.write_fd(), &null_char, sizeof(null_char)); + }; + std::signal(SIGINT, signal_wrapper); + std::signal(SIGTERM, signal_wrapper); + + const size_t buffer_size = 4096; + Buffer client_to_underlying(buffer_size), underlying_to_client(buffer_size), scratch(buffer_size); + + fd_set descriptors; + const int nfds = std::max(std::max(to_client.fd(), to_underlying.fd()), sentinel.read_fd()) + 1; + while (true) { + FD_ZERO(&descriptors); + FD_SET(to_client.fd(), &descriptors); + FD_SET(to_underlying.fd(), &descriptors); + FD_SET(sentinel.read_fd(), &descriptors); + + const int result = select(nfds, &descriptors, nullptr, nullptr, nullptr); + if (result < 0) { + perror("Select"); + break; + } + + if (FD_ISSET(to_client.fd(), &descriptors)) { + // Forward client message to underlying. + const ssize_t bytes_read = read(to_client.fd(), client_to_underlying.buffer(), buffer_size); + if (bytes_read < 0) { + perror("Read from client"); + break; + } + + ssize_t left_to_write = bytes_read; + uint8_t* buffer = client_to_underlying.buffer(); + while (left_to_write > 0) { + const ssize_t bytes_written = write(to_underlying.fd(), buffer, left_to_write); + if (bytes_written < 0) { + perror("Write to underlying"); + break; + } + left_to_write -= bytes_written; + buffer += bytes_written; + } + } + + if (FD_ISSET(to_underlying.fd(), &descriptors)) { + // Forward underlying message to client. + const ssize_t bytes_read = read(to_underlying.fd(), underlying_to_client.buffer(), buffer_size); + if (bytes_read < 0) { + perror("Read from underlying"); + break; + } + + ssize_t left_to_write = bytes_read; + uint8_t* buffer = underlying_to_client.buffer(); + while (left_to_write > 0) { + const ssize_t bytes_written = write(to_client.fd(), buffer, left_to_write); + if (bytes_written < 0) { + perror("Write to client"); + break; + } + left_to_write -= bytes_written; + buffer += bytes_written; + } + } + + if (FD_ISSET(sentinel.read_fd(), &descriptors)) { + read(sentinel.read_fd(), scratch.buffer(), 1); + break; + } + } + + std::cerr << "Done and exiting." << std::endl; + return 0; +} diff --git a/experiments/18-proxy/odbc_noop.py b/experiments/18-proxy/odbc_noop.py new file mode 100644 index 00000000..7963ac2a --- /dev/null +++ b/experiments/18-proxy/odbc_noop.py @@ -0,0 +1,33 @@ +import argparse + +from brad.config.engine import Engine +from brad.config.file import ConfigFile +from brad.connection.connection import Connection +from brad.connection.factory import ConnectionFactory +from brad.connection.odbc_connection import OdbcConnection + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--address", type=str, required=True) + parser.add_argument("--port", type=str, required=True) + parser.add_argument("--physical-config-file", type=str, required=True) + args = parser.parse_args() + + config = ConfigFile.load_from_physical_config(args.physical_config_file) + cstr = ConnectionFactory._pg_aurora_odbc_connection_string( + args.address, + args.port, + config.get_connection_details(Engine.Aurora), + schema_name=None, + ) + cxn: Connection = OdbcConnection.connect_sync(cstr, autocommit=True, timeout_s=30) + cursor = cxn.cursor_sync() + cursor.execute_sync("SELECT 1") + print(cursor.fetchall_sync()) + + cxn.close_sync() + + +if __name__ == "__main__": + main() diff --git a/setup.py b/setup.py index 9b5261ee..c321b6d7 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,7 @@ "pandas", "scikit-learn==1.3.0", "types-pytz", - "numpy", + "numpy==1.25.2", "imbalanced-learn", "redshift_connector", "tabulate",