Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a websockets tunnel and a test for the proxy's websockets support. #3823

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions test_runner/regress/test_proxy_websockets.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from __future__ import annotations

import asyncio
import ssl

import asyncpg
import pytest
import websocket_tunnel
import websockets
from fixtures.log_helper import log
from fixtures.neon_fixtures import NeonProxy
from fixtures.port_distributor import PortDistributor


@pytest.mark.asyncio
Expand Down Expand Up @@ -196,3 +201,53 @@ async def test_websockets_pipelined(static_proxy: NeonProxy):
# close
await websocket.send(b"X\x00\x00\x00\x04")
await websocket.wait_closed()


@pytest.mark.asyncio
async def test_websockets_tunneled(static_proxy: NeonProxy, port_distributor: PortDistributor):
static_proxy.safe_psql("create user ws_auth with password 'ws' superuser")

user = "ws_auth"
password = "ws"

ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ssl_context.load_verify_locations(str(static_proxy.test_output_dir / "proxy.crt"))

# Launch a tunnel service so that we can speak the websockets protocol to
# the proxy
tunnel_port = port_distributor.get_port()
tunnel_server = await websocket_tunnel.start_server(
"127.0.0.1",
tunnel_port,
f"wss://{static_proxy.domain}:{static_proxy.external_http_port}/sql",
ssl_context,
)
log.info(f"websockets tunnel listening for connections on port {tunnel_port}")

async with tunnel_server:

async def run_tunnel():
try:
async with tunnel_server:
await tunnel_server.serve_forever()
except Exception as e:
log.error(f"Error in tunnel task: {e}")

tunnel_task = asyncio.create_task(run_tunnel())

# Ok, the tunnel is now running. Check that we can connect to the proxy's
# websocket interface, through the tunnel
tunnel_connstring = f"postgres://{user}:{password}@127.0.0.1:{tunnel_port}/postgres"

log.info(f"connecting to {tunnel_connstring}")
conn = await asyncpg.connect(tunnel_connstring)
res = await conn.fetchval("SELECT 123")
assert res == 123
await conn.close()
log.info("Ran a query successfully through the tunnel")

tunnel_server.close()
try:
await tunnel_task
except asyncio.CancelledError:
pass
155 changes: 155 additions & 0 deletions test_runner/websocket_tunnel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
#!/usr/bin/env python3
#
# This program helps to test the WebSocket tunneling in proxy. It listens for a TCP
# connection on a port, and when you connect to it, it opens a websocket connection,
# and forwards all the traffic to the websocket connection, wrapped in WebSocket binary
# frames.
#
# This is used in the test_proxy::test_websockets test, but it is handy for manual testing too.
#
# Usage for manual testing:
#
# ## Launch Posgres on port 3000:
# postgres -D data -p3000
#
# ## Launch proxy with WSS enabled:
# openssl req -new -x509 -days 365 -nodes -text -out server.crt -keyout server.key -subj '/CN=*.neon.localtest.me'
# ./target/debug/proxy --wss 127.0.0.1:40433 --http 127.0.0.1:28080 --mgmt 127.0.0.1:9099 --proxy 127.0.0.1:4433 --tls-key server.key --tls-cert server.crt --auth-backend postgres
#
# ## Launch the tunnel:
#
# poetry run ./test_runner/websocket_tunnel.py --ws-port 40433 --ws-url "wss://ep-test.neon.localtest.me"
#
# ## Now you can connect with psql:
# psql "postgresql://heikki@localhost:40433/postgres"
#

import argparse
import asyncio
import logging
import ssl
from ssl import Purpose

import websockets
from fixtures.log_helper import log

# Enable verbose logging of all the traffic

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just a nit: it seems like we a couple of extra lines here


def enable_verbose_logging():
logger = logging.getLogger("websockets")
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler())


async def start_server(tcp_listen_host, tcp_listen_port, ws_url, ctx):
server = await asyncio.start_server(
lambda r, w: handle_client(r, w, ws_url, ctx), tcp_listen_host, tcp_listen_port
)
return server


async def handle_tcp_to_websocket(tcp_reader, ws):
try:
while not tcp_reader.at_eof():
data = await tcp_reader.read(1024)

await ws.send(data)
except websockets.exceptions.ConnectionClosedError as e:
log.debug(f"connection closed: {e}")
except websockets.exceptions.ConnectionClosedOK:
log.debug("connection closed")
except Exception as e:
log.error(e)


async def handle_websocket_to_tcp(ws, tcp_writer):
try:
async for message in ws:
tcp_writer.write(message)
await tcp_writer.drain()
except websockets.exceptions.ConnectionClosedError as e:
log.debug(f"connection closed: {e}")
except websockets.exceptions.ConnectionClosedOK:
log.debug("connection closed")
except Exception as e:
log.error(e)


async def handle_client(tcp_reader, tcp_writer, ws_url: str, ctx: ssl.SSLContext):
try:
log.info("Received TCP connection. Connecting to websockets proxy.")

async with websockets.connect(ws_url, ssl=ctx) as ws:
try:
log.info("Connected to websockets proxy")

async with asyncio.TaskGroup() as tg:
task1 = tg.create_task(handle_tcp_to_websocket(tcp_reader, ws))
task2 = tg.create_task(handle_websocket_to_tcp(ws, tcp_writer))

done, pending = await asyncio.wait(
[task1, task2], return_when=asyncio.FIRST_COMPLETED
)
tcp_writer.close()
await ws.close()

except* Exception as ex:
log.error(ex.exceptions)
except Exception as e:
log.error(e)


async def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--tcp-listen-addr",
default="localhost",
help="TCP addr to listen on",
)
parser.add_argument(
"--tcp-listen-port",
default="40444",
help="TCP port to listen on",
)

parser.add_argument(
"--ws-url",
default="wss://localhost/",
help="websocket URL to connect to. This determines the Host header sent to the server",
)
parser.add_argument(
"--ws-host",
default="127.0.0.1",
help="websockets host to connect to",
)
parser.add_argument(
"--ws-port",
type=int,
default=443,
help="websockets port to connect to",
)
parser.add_argument(
"--verbose",
action="store_true",
help="enable verbose logging",
)
args = parser.parse_args()

if args.verbose:
enable_verbose_logging()

ctx = ssl.create_default_context(Purpose.SERVER_AUTH)
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE

server = await start_server(args.tcp_listen_addr, args.tcp_listen_port, args.ws_url, ctx)
print(
f"Listening for connections at {args.tcp_listen_addr}:{args.tcp_listen_port}, forwarding them to {args.ws_host}:{args.ws_port}"
)
async with server:
await server.serve_forever()


if __name__ == "__main__":
asyncio.run(main())
Loading