Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check in experimental byte-level proxying code #512

Merged
merged 4 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
cmake_minimum_required(VERSION 3.16)

# Opt-in switch (off by default) for building the experimental code.
# Note: option() takes the help text second and the default value third.
option(BRAD_BUILD_EXPERIMENTAL "Set to build the experimental code." OFF)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC")
Expand All @@ -16,6 +18,10 @@ find_package(Boost REQUIRED)

add_subdirectory(third_party)

# Only descend into the experimental/ subdirectory when the
# BRAD_BUILD_EXPERIMENTAL option is enabled.
if(BRAD_BUILD_EXPERIMENTAL)
add_subdirectory(experimental)
endif()

add_library(sqlite_server_lib OBJECT
sqlite_server/sqlite_server.cc
sqlite_server/sqlite_sql_info.cc
Expand Down
2 changes: 2 additions & 0 deletions cpp/experimental/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Builds the `proxy_socket` executable from proxy_socket.cc and links it
# against gflags (used by that source file for command-line flag parsing).
add_executable(proxy_socket proxy_socket.cc)
target_link_libraries(proxy_socket PRIVATE gflags)
276 changes: 276 additions & 0 deletions cpp/experimental/proxy_socket.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,276 @@
#include <algorithm>
#include <cerrno>
#include <csignal>
#include <cstdint>
#include <cstdio>
#include <functional>
#include <iostream>
#include <stdexcept>
#include <string>

#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>

#include <gflags/gflags.h>

DEFINE_int32(port, 31337, "Port that this server should listen on.");

DEFINE_int32(proxy_to_port, 5439, "Port that this server should proxy its connection to.");
DEFINE_string(proxy_to_host, "", "The host that this server should proxy its connection to.");

namespace {

// Owns a connected TCP socket file descriptor and closes it on destruction.
//
// Instances are created by `Socket::Connect()` (outbound connection) or by
// `ServerSocket::Accept()` (inbound connection). Copying is disabled; the
// factory functions return by value and rely on C++17 guaranteed copy elision.
class Socket {
 public:
  // Opens a TCP connection to `host`:`port`.
  //
  // `host` is parsed with `inet_pton(AF_INET, ...)`, so it must be a textual
  // IPv4 address; hostnames are not resolved.
  //
  // Throws `std::runtime_error` if the address cannot be parsed, if the socket
  // cannot be created, or if the connection attempt fails.
  static Socket Connect(const std::string& host, const uint16_t port) {
    // Zero-initialize so that no field (e.g. padding) is left indeterminate.
    struct sockaddr_in serv_addr = {};
    serv_addr.sin_family = AF_INET;
    serv_addr.sin_port = htons(port);

    // inet_pton() returns 1 on success, 0 when `host` is not a valid address
    // (errno is NOT set in that case), and -1 on error. Both 0 and -1 must be
    // rejected. Parsing happens before the socket is created so that there is
    // no descriptor to clean up on this failure path.
    const int conversion = inet_pton(AF_INET, host.c_str(), &serv_addr.sin_addr);
    if (conversion < 0) {
      perror("Host conversion.");
      throw std::runtime_error("Host conversion.");
    }
    if (conversion == 0) {
      std::cerr << "Host conversion: '" << host
                << "' is not a valid IPv4 address." << std::endl;
      throw std::runtime_error("Host conversion.");
    }

    const int fd = socket(AF_INET, SOCK_STREAM, 0);
    if (fd < 0) {
      perror("Socket failed.");
      throw std::runtime_error("Socket failed.");
    }

    if (connect(fd, reinterpret_cast<struct sockaddr*>(&serv_addr),
                sizeof(serv_addr)) < 0) {
      // Report first (perror reads errno), then release the descriptor so it
      // does not leak when we throw.
      perror("Connect failed");
      close(fd);
      throw std::runtime_error("Connect failed.");
    }

    return Socket(fd);
  }

  // No copying or copy assignment.
  Socket(const Socket&) = delete;
  Socket& operator=(const Socket&) = delete;

  ~Socket() {
    if (fd_ >= 0) close(fd_);
  }

  // Returns the underlying file descriptor (still owned by this object).
  int fd() const { return fd_; }

 private:
  // ServerSocket::Accept() wraps accepted descriptors using the private
  // constructor below.
  friend class ServerSocket;
  explicit Socket(int fd) : fd_(fd) {}

  int fd_;
};

// Owns a listening TCP socket bound to a port on all IPv4 interfaces
// (INADDR_ANY) and closes it on destruction.
class ServerSocket {
 public:
  // Creates the socket, enables SO_REUSEADDR and SO_REUSEPORT, binds it to
  // `port` and starts listening with a backlog of 1.
  //
  // Throws `std::runtime_error` if any step fails. The descriptor is closed
  // before throwing, because the destructor does not run for an object whose
  // constructor threw.
  explicit ServerSocket(uint16_t port) : port_(port), fd_(-1) {
    // Creating socket file descriptor. socket() signals failure with -1;
    // 0 is a valid descriptor.
    fd_ = socket(AF_INET, SOCK_STREAM, 0);
    if (fd_ < 0) {
      Fail("Socket failed", "Socket failed");
    }

    // SO_REUSEADDR and SO_REUSEPORT are distinct option names, not bit flags,
    // so each one needs its own setsockopt() call.
    const int opt = 1;
    if (setsockopt(fd_, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) < 0) {
      Fail("setsockopt", "setsockopt");
    }
    if (setsockopt(fd_, SOL_SOCKET, SO_REUSEPORT, &opt, sizeof(opt)) < 0) {
      Fail("setsockopt", "setsockopt");
    }

    struct sockaddr_in address = {};
    address.sin_family = AF_INET;
    address.sin_addr.s_addr = INADDR_ANY;
    address.sin_port = htons(port);

    if (bind(fd_, reinterpret_cast<struct sockaddr*>(&address),
             sizeof(address)) < 0) {
      Fail("bind failed", "bind failed");
    }

    if (listen(fd_, 1) < 0) {
      Fail("listen", "listen failed");
    }
  }

  // Blocks until a client connects and returns the connected socket.
  // Throws `std::runtime_error` if accept() fails.
  Socket Accept() const {
    struct sockaddr_in address;
    socklen_t addrlen = sizeof(address);
    const int new_fd =
        accept(fd_, reinterpret_cast<struct sockaddr*>(&address), &addrlen);
    if (new_fd < 0) {
      perror("Accept failed");
      throw std::runtime_error("Accept failed");
    }
    return Socket(new_fd);
  }

  ~ServerSocket() {
    if (fd_ >= 0) close(fd_);
  }

  // No copying or copy assignment.
  ServerSocket(const ServerSocket&) = delete;
  ServerSocket& operator=(const ServerSocket&) = delete;

  // Returns the listening file descriptor (still owned by this object).
  int fd() const { return fd_; }

 private:
  // Reports the current errno via perror(`perror_label`), releases the
  // descriptor (if one was opened) and throws `std::runtime_error(what)`.
  // perror() runs before close() so that errno still describes the failure.
  [[noreturn]] void Fail(const char* perror_label, const char* what) {
    perror(perror_label);
    if (fd_ >= 0) {
      close(fd_);
      fd_ = -1;
    }
    throw std::runtime_error(what);
  }

  uint16_t port_;
  int fd_;
};

// Owns both ends of an anonymous pipe and closes them on destruction.
//
// In this file the write end is written to from the signal handler and the
// read end is watched with select(), so that a signal can wake up the main
// loop.
class SentinelPipe {
 public:
  // Creates the pipe. Throws `std::runtime_error` if pipe() fails.
  SentinelPipe() : fd_{-1, -1} {
    if (pipe(fd_) < 0) {
      perror("Pipe failed.");
      throw std::runtime_error("Pipe failed");
    }
  }

  ~SentinelPipe() {
    for (int& fd : fd_) {
      // 0 is a valid descriptor (pipe() returns it when stdin is closed), so
      // only the -1 "not open" marker may be skipped.
      if (fd >= 0) {
        close(fd);
        fd = -1;
      }
    }
  }

  // No copying or copy assignment.
  SentinelPipe(const SentinelPipe&) = delete;
  SentinelPipe& operator=(const SentinelPipe&) = delete;

  // Read end of the pipe.
  int read_fd() const { return fd_[0]; }
  // Write end of the pipe.
  int write_fd() const { return fd_[1]; }

 private:
  // fd_[0] is the read end, fd_[1] is the write end (as filled in by pipe()).
  int fd_[2];
};

// Owns a heap-allocated byte array of a fixed size chosen at construction.
//
// The size is not stored; callers must remember how many bytes they asked for.
class Buffer {
 public:
  // Allocates `size` bytes. The contents are uninitialized.
  explicit Buffer(size_t size) : buf_(new uint8_t[size]) {}

  // The storage was obtained with new[], so it must be released with delete[]
  // (plain `delete` on an array is undefined behaviour).
  ~Buffer() {
    delete[] buf_;
    buf_ = nullptr;
  }

  // No copying or copy assignment: a copy would share `buf_` and free it twice.
  Buffer(const Buffer&) = delete;
  Buffer& operator=(const Buffer&) = delete;

  // Returns a pointer to the first byte of the storage (still owned by this
  // object).
  uint8_t* buffer() const { return buf_; }

 private:
  uint8_t* buf_;
};

// Callback invoked by `signal_wrapper()`. It is empty until assigned; while it
// is empty, delivered signals are ignored by the wrapper.
std::function<void(int)> g_handle_signal;

// Plain-function trampoline suitable for passing to `std::signal()`: forwards
// the signal number to `g_handle_signal` when a callback has been assigned and
// does nothing otherwise.
void signal_wrapper(int signum) {
  if (g_handle_signal) {
    g_handle_signal(signum);
  }
}

} // namespace

namespace {

// Returns true if `port` fits in the 16-bit range used for TCP ports.
bool IsValidPort(int32_t port) { return port >= 0 && port <= 65535; }

// Reads one chunk of at most `capacity` bytes from `from_fd` into `chunk` and
// writes all of it to `to_fd`, looping over short writes.
//
// Returns true if proxying should continue. Returns false when `from_fd`
// reports EOF, or when a read/write fails (the failure is reported through
// perror() using `read_label` / `write_label`).
bool ForwardChunk(int from_fd, int to_fd, uint8_t* chunk, size_t capacity,
                  const char* read_label, const char* write_label) {
  ssize_t bytes_read;
  do {
    bytes_read = read(from_fd, chunk, capacity);
  } while (bytes_read < 0 && errno == EINTR);

  if (bytes_read < 0) {
    perror(read_label);
    return false;
  }
  if (bytes_read == 0) {
    // EOF: the peer closed its side. Without this check select() would keep
    // reporting the descriptor as readable and the loop would spin forever.
    return false;
  }

  ssize_t left_to_write = bytes_read;
  const uint8_t* cursor = chunk;
  while (left_to_write > 0) {
    const ssize_t bytes_written = write(to_fd, cursor, left_to_write);
    if (bytes_written < 0) {
      if (errno == EINTR) continue;
      perror(write_label);
      return false;
    }
    left_to_write -= bytes_written;
    cursor += bytes_written;
  }
  return true;
}

// Accepts one client, connects to the proxied-to host and shuffles bytes
// between the two until EOF, an I/O error, or SIGINT/SIGTERM.
//
// Returns the process exit code. May throw `std::runtime_error` if setting up
// the sockets or the sentinel pipe fails.
int RunProxy() {
  ServerSocket server(static_cast<uint16_t>(FLAGS_port));
  std::cerr << "Listening for a connection on port " << FLAGS_port << std::endl;

  const Socket to_client = server.Accept();
  std::cerr << "Accepted client connection." << std::endl;

  std::cerr << "Connecting to " << FLAGS_proxy_to_host << ":" << FLAGS_proxy_to_port << std::endl;
  const Socket to_underlying =
      Socket::Connect(FLAGS_proxy_to_host, static_cast<uint16_t>(FLAGS_proxy_to_port));
  std::cerr << "Connection succeeded." << std::endl;

  // Handle early exit (Ctrl+C or SIGTERM): the handler writes one byte to the
  // sentinel pipe, which makes the pipe's read end readable in select().
  SentinelPipe sentinel;
  g_handle_signal = [&sentinel](int) {
    const char null_char = '\0';
    // Nothing useful can be done here if the write fails.
    const ssize_t ignored = write(sentinel.write_fd(), &null_char, sizeof(null_char));
    (void)ignored;
  };
  std::signal(SIGINT, signal_wrapper);
  std::signal(SIGTERM, signal_wrapper);
  // Writing to a socket whose peer has gone away would otherwise raise SIGPIPE
  // and kill the process; with it ignored, write() fails with EPIPE instead
  // and the error is reported normally.
  std::signal(SIGPIPE, SIG_IGN);

  const size_t buffer_size = 4096;
  Buffer client_to_underlying(buffer_size), underlying_to_client(buffer_size);

  const int nfds = std::max(std::max(to_client.fd(), to_underlying.fd()), sentinel.read_fd()) + 1;
  while (true) {
    // select() modifies the set, so it has to be rebuilt on every iteration.
    fd_set descriptors;
    FD_ZERO(&descriptors);
    FD_SET(to_client.fd(), &descriptors);
    FD_SET(to_underlying.fd(), &descriptors);
    FD_SET(sentinel.read_fd(), &descriptors);

    const int result = select(nfds, &descriptors, nullptr, nullptr, nullptr);
    if (result < 0) {
      // A delivered signal interrupts select(). The handler has already
      // written to the sentinel pipe, so simply retry and pick it up below.
      if (errno == EINTR) continue;
      perror("Select");
      break;
    }

    if (FD_ISSET(to_client.fd(), &descriptors)) {
      // Forward client message to underlying.
      if (!ForwardChunk(to_client.fd(), to_underlying.fd(), client_to_underlying.buffer(),
                        buffer_size, "Read from client", "Write to underlying")) {
        break;
      }
    }

    if (FD_ISSET(to_underlying.fd(), &descriptors)) {
      // Forward underlying message to client.
      if (!ForwardChunk(to_underlying.fd(), to_client.fd(), underlying_to_client.buffer(),
                        buffer_size, "Read from underlying", "Write to client")) {
        break;
      }
    }

    if (FD_ISSET(sentinel.read_fd(), &descriptors)) {
      // Drain the byte written by the signal handler, then shut down.
      char drained = '\0';
      const ssize_t ignored = read(sentinel.read_fd(), &drained, sizeof(drained));
      (void)ignored;
      break;
    }
  }

  // `sentinel` is about to be destroyed; make sure no later signal can reach
  // the callback that refers to it.
  std::signal(SIGINT, SIG_DFL);
  std::signal(SIGTERM, SIG_DFL);
  g_handle_signal = nullptr;

  std::cerr << "Done and exiting." << std::endl;
  return 0;
}

}  // namespace

// Entry point: parses the command-line flags, validates them and runs the
// proxy. Returns 0 on a normal shutdown and 1 on invalid flags or when the
// sockets could not be set up.
int main(int argc, char* argv[]) {
  gflags::SetUsageMessage("Proxies TCP connections.");
  gflags::ParseCommandLineFlags(&argc, &argv, true);

  if (FLAGS_proxy_to_host.empty()) {
    // The flag is defined as `proxy_to_host`, so that is the spelling to show.
    std::cerr << "ERROR: Must provide a value for --proxy_to_host" << std::endl;
    return 1;
  }
  // The flags are 32-bit integers but the socket API takes 16-bit ports;
  // reject out-of-range values instead of silently truncating them.
  if (!IsValidPort(FLAGS_port) || !IsValidPort(FLAGS_proxy_to_port)) {
    std::cerr << "ERROR: --port and --proxy_to_port must be in the range [0, 65535]" << std::endl;
    return 1;
  }

  // Workflow:
  // - Start a socket listening for connections on `port`
  // - Once we accept one connection, open a socket to the proxied-to host/port
  // - Shuffle bytes to and from the two connections
  // - Close the sockets on Ctrl-C or when there is an EOF
  try {
    return RunProxy();
  } catch (const std::exception& error) {
    std::cerr << "ERROR: " << error.what() << std::endl;
    return 1;
  }
}
33 changes: 33 additions & 0 deletions experiments/18-proxy/odbc_noop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import argparse

from brad.config.engine import Engine
from brad.config.file import ConfigFile
from brad.connection.connection import Connection
from brad.connection.factory import ConnectionFactory
from brad.connection.odbc_connection import OdbcConnection


def main() -> None:
    """Opens an ODBC connection, runs `SELECT 1` and prints the result.

    The connection string is built for Aurora from the credentials in
    `--physical-config-file`, but it targets the `--address` / `--port` given
    on the command line.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--address", type=str, required=True)
    parser.add_argument("--port", type=str, required=True)
    parser.add_argument("--physical-config-file", type=str, required=True)
    args = parser.parse_args()

    config = ConfigFile.load_from_physical_config(args.physical_config_file)
    cstr = ConnectionFactory._pg_aurora_odbc_connection_string(
        args.address,
        args.port,
        config.get_connection_details(Engine.Aurora),
        schema_name=None,
    )
    cxn: Connection = OdbcConnection.connect_sync(cstr, autocommit=True, timeout_s=30)
    try:
        cursor = cxn.cursor_sync()
        cursor.execute_sync("SELECT 1")
        print(cursor.fetchall_sync())
    finally:
        # Always release the connection, even if the query fails.
        cxn.close_sync()


if __name__ == "__main__":
    main()
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
"pandas",
"scikit-learn==1.3.0",
"types-pytz",
"numpy",
"numpy==1.25.2",
"imbalanced-learn",
"redshift_connector",
"tabulate",
Expand Down