Skip to content

Commit

Permalink
Check in experimental byte-level proxying code (#512)
Browse files Browse the repository at this point in the history
For high level experiments.
  • Loading branch information
geoffxy authored Jun 24, 2024
1 parent 9683936 commit 8e3199f
Show file tree
Hide file tree
Showing 5 changed files with 318 additions and 1 deletion.
6 changes: 6 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
cmake_minimum_required(VERSION 3.16)

option(BRAD_BUILD_EXPERIMENTAL OFF "Set to build the experimental code.")

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC")
Expand All @@ -16,6 +18,10 @@ find_package(Boost REQUIRED)

add_subdirectory(third_party)

if(BRAD_BUILD_EXPERIMENTAL)
add_subdirectory(experimental)
endif()

add_library(sqlite_server_lib OBJECT
sqlite_server/sqlite_server.cc
sqlite_server/sqlite_sql_info.cc
Expand Down
2 changes: 2 additions & 0 deletions cpp/experimental/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
add_executable(proxy_socket proxy_socket.cc)
target_link_libraries(proxy_socket PRIVATE gflags)
276 changes: 276 additions & 0 deletions cpp/experimental/proxy_socket.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,276 @@
#include <iostream>
#include <stdexcept>
#include <csignal>
#include <functional>

#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <unistd.h>

#include <gflags/gflags.h>

DEFINE_int32(port, 31337, "Port that this server should listen on.");

DEFINE_int32(proxy_to_port, 5439, "Port that this server should proxy its connection to.");
DEFINE_string(proxy_to_host, "", "The host that this server should proxy its connection to.");

namespace {

class Socket {
public:
static Socket Connect(const std::string& host, const uint16_t port) {
const int fd = socket(AF_INET, SOCK_STREAM, 0);
if (fd < 0) {
perror("Socket failed.");
throw std::runtime_error("Socket failed.");
}

struct sockaddr_in serv_addr;
serv_addr.sin_family = AF_INET;
serv_addr.sin_port = htons(port);

if(inet_pton(AF_INET, host.c_str(), &serv_addr.sin_addr) < 0) {
perror("Host conversion.");
throw std::runtime_error("Host conversion.");
}

if (connect(fd, reinterpret_cast<struct sockaddr *>(&serv_addr), sizeof(serv_addr)) < 0) {
perror("Connect failed");
throw std::runtime_error("Connect failed.");
}

return Socket(fd);
}

// No copying or copy assignment.
Socket(const Socket&) = delete;
Socket& operator=(const Socket&) = delete;

~Socket() { close(fd_); }

int fd() const { return fd_; }

private:
friend class ServerSocket;
explicit Socket(int fd) : fd_(fd) {}

int fd_;
};

class ServerSocket {
public:
explicit ServerSocket(uint16_t port) : port_(port), fd_(-1) {
struct sockaddr_in address;
int opt = 1;
int addrlen = sizeof(address);

// Creating socket file descriptor
fd_ = socket(AF_INET, SOCK_STREAM, 0);
if (fd_ == 0) {
perror("Socket failed");
throw std::runtime_error("Socket failed");
}

if (setsockopt(fd_, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &opt, sizeof(opt))) {
perror("setsockopt");
throw std::runtime_error("setsockopt");
}

address.sin_family = AF_INET;
address.sin_addr.s_addr = INADDR_ANY;
address.sin_port = htons(port);

if (bind(fd_, reinterpret_cast<struct sockaddr *>(&address), sizeof(address)) < 0) {
perror("bind failed");
throw std::runtime_error("bind failed");
}

if (listen(fd_, 1) < 0) {
perror("listen");
throw std::runtime_error("listen failed");
}
}

Socket Accept() const {
struct sockaddr_in address;
socklen_t addrlen = sizeof(address);
const int new_fd = accept(fd_, reinterpret_cast<struct sockaddr *>(&address), &addrlen);
if (new_fd < 0) {
perror("Accept failed");
throw std::runtime_error("Accept failed");
}
return Socket(new_fd);
}

~ServerSocket() { close(fd_); }

// No copying or copy assignment.
ServerSocket(const ServerSocket&) = delete;
ServerSocket& operator=(const ServerSocket&) = delete;

int fd() const { return fd_; }

private:
uint16_t port_;
int fd_;
};

class SentinelPipe {
public:
SentinelPipe() {
if (pipe(fd_) < 0) {
perror("Pipe failed.");
throw std::runtime_error("Pipe failed");
}
}

~SentinelPipe() {
for (int i = 0; i < 2; ++i) {
if (fd_[i] > 0) {
close(fd_[i]);
fd_[i] = -1;
}
}
}

SentinelPipe(const SentinelPipe&) = delete;
SentinelPipe& operator=(const SentinelPipe&) = delete;

int read_fd() const { return fd_[0]; }
int write_fd() const { return fd_[1]; }

private:
int fd_[2];
};

class Buffer {
public:
Buffer(size_t size) : buf_(nullptr) {
buf_ = new uint8_t[size];
}

~Buffer() {
if (buf_ == nullptr) return;
delete buf_;
buf_ = nullptr;
}

uint8_t* buffer() const { return buf_; }

private:
uint8_t* buf_;
};

std::function<void(int)> g_handle_signal;

void signal_wrapper(int signal) {
if (!g_handle_signal) return;
g_handle_signal(signal);
}

} // namespace

int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Proxies TCP connections.");
gflags::ParseCommandLineFlags(&argc, &argv, true);

if (FLAGS_proxy_to_host.empty()) {
std::cerr << "ERROR: Must provide a value for --proxy-to-host" << std::endl;
return 1;
}

// Workflow:
// - Start a socket listening for connections on `port`
// - Once we accept one connection, open a socket to the proxied-to host/port
// - Shuffle bytes to and from the two connections
// - Close the sockets on Ctrl-C or when there is an EOF

ServerSocket server(FLAGS_port);
std::cerr << "Listening for a connection on port " << FLAGS_port << std::endl;

const Socket to_client = server.Accept();
std::cerr << "Accepted client connection." << std::endl;

std::cerr << "Connecting to " << FLAGS_proxy_to_host << ":" << FLAGS_proxy_to_port << std::endl;
const Socket to_underlying = Socket::Connect(FLAGS_proxy_to_host, FLAGS_proxy_to_port);
std::cerr << "Connection succeeded." << std::endl;

// Handle early exit (Ctrl+C or SIGTERM).
SentinelPipe sentinel;
g_handle_signal = [&sentinel](int signal) {
char null_char = '\0';
write(sentinel.write_fd(), &null_char, sizeof(null_char));
};
std::signal(SIGINT, signal_wrapper);
std::signal(SIGTERM, signal_wrapper);

const size_t buffer_size = 4096;
Buffer client_to_underlying(buffer_size), underlying_to_client(buffer_size), scratch(buffer_size);

fd_set descriptors;
const int nfds = std::max(std::max(to_client.fd(), to_underlying.fd()), sentinel.read_fd()) + 1;
while (true) {
FD_ZERO(&descriptors);
FD_SET(to_client.fd(), &descriptors);
FD_SET(to_underlying.fd(), &descriptors);
FD_SET(sentinel.read_fd(), &descriptors);

const int result = select(nfds, &descriptors, nullptr, nullptr, nullptr);
if (result < 0) {
perror("Select");
break;
}

if (FD_ISSET(to_client.fd(), &descriptors)) {
// Forward client message to underlying.
const ssize_t bytes_read = read(to_client.fd(), client_to_underlying.buffer(), buffer_size);
if (bytes_read < 0) {
perror("Read from client");
break;
}

ssize_t left_to_write = bytes_read;
uint8_t* buffer = client_to_underlying.buffer();
while (left_to_write > 0) {
const ssize_t bytes_written = write(to_underlying.fd(), buffer, left_to_write);
if (bytes_written < 0) {
perror("Write to underlying");
break;
}
left_to_write -= bytes_written;
buffer += bytes_written;
}
}

if (FD_ISSET(to_underlying.fd(), &descriptors)) {
// Forward underlying message to client.
const ssize_t bytes_read = read(to_underlying.fd(), underlying_to_client.buffer(), buffer_size);
if (bytes_read < 0) {
perror("Read from underlying");
break;
}

ssize_t left_to_write = bytes_read;
uint8_t* buffer = underlying_to_client.buffer();
while (left_to_write > 0) {
const ssize_t bytes_written = write(to_client.fd(), buffer, left_to_write);
if (bytes_written < 0) {
perror("Write to client");
break;
}
left_to_write -= bytes_written;
buffer += bytes_written;
}
}

if (FD_ISSET(sentinel.read_fd(), &descriptors)) {
read(sentinel.read_fd(), scratch.buffer(), 1);
break;
}
}

std::cerr << "Done and exiting." << std::endl;
return 0;
}
33 changes: 33 additions & 0 deletions experiments/18-proxy/odbc_noop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import argparse

from brad.config.engine import Engine
from brad.config.file import ConfigFile
from brad.connection.connection import Connection
from brad.connection.factory import ConnectionFactory
from brad.connection.odbc_connection import OdbcConnection


def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--address", type=str, required=True)
parser.add_argument("--port", type=str, required=True)
parser.add_argument("--physical-config-file", type=str, required=True)
args = parser.parse_args()

config = ConfigFile.load_from_physical_config(args.physical_config_file)
cstr = ConnectionFactory._pg_aurora_odbc_connection_string(
args.address,
args.port,
config.get_connection_details(Engine.Aurora),
schema_name=None,
)
cxn: Connection = OdbcConnection.connect_sync(cstr, autocommit=True, timeout_s=30)
cursor = cxn.cursor_sync()
cursor.execute_sync("SELECT 1")
print(cursor.fetchall_sync())

cxn.close_sync()


if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
"pandas",
"scikit-learn==1.3.0",
"types-pytz",
"numpy",
"numpy==1.25.2",
"imbalanced-learn",
"redshift_connector",
"tabulate",
Expand Down

0 comments on commit 8e3199f

Please sign in to comment.