Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
conradludgate committed Jan 10, 2025
1 parent 48c91b2 commit 7f68453
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 17 deletions.
15 changes: 11 additions & 4 deletions test_runner/regress/test_proxy_websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,15 +204,22 @@ async def test_websockets_pipelined(static_proxy: NeonProxy):

@pytest.mark.asyncio
async def test_websockets_tunneled(static_proxy: NeonProxy, port_distributor: PortDistributor):
static_proxy.safe_psql("create user ws_auth with password 'ws' superuser")

user = "ws_auth"
password = "ws"

ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ssl_context.load_verify_locations(str(static_proxy.test_output_dir / "proxy.crt"))

# Launch a tunnel service so that we can speak the websockets protocol to
# the proxy
tunnel_port = port_distributor.get_port()
tunnel_server = await websocket_tunnel.start_server(
"127.0.0.1",
tunnel_port,
f"wss://ep-static-test.neon.localtest.me:{static_proxy.external_http_port}",
"127.0.0.1",
static_proxy.external_http_port,
f"wss://{static_proxy.domain}:{static_proxy.external_http_port}/sql",
ssl_context,
)
log.info(f"websockets tunnel listening for connections on port {tunnel_port}")

Expand All @@ -229,7 +236,7 @@ async def run_tunnel():

# Ok, the tunnel is now running. Check that we can connect to the proxy's
# websocket interface, through the tunnel
tunnel_connstring = f"postgres://proxy:[email protected]:{tunnel_port}/postgres"
tunnel_connstring = f"postgres://{user}:{password}@127.0.0.1:{tunnel_port}/postgres"

log.info(f"connecting to {tunnel_connstring}")
conn = await asyncpg.connect(tunnel_connstring)
Expand Down
24 changes: 11 additions & 13 deletions test_runner/websocket_tunnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ def enable_verbose_logging():
logger.addHandler(logging.StreamHandler())


async def start_server(tcp_listen_host, tcp_listen_port, ws_url, ws_host, ws_port):
async def start_server(tcp_listen_host, tcp_listen_port, ws_url, ctx):
server = await asyncio.start_server(
lambda r, w: handle_client(r, w, ws_url, ws_host, ws_port), tcp_listen_host, tcp_listen_port
lambda r, w: handle_client(r, w, ws_url, ctx), tcp_listen_host, tcp_listen_port
)
return server

Expand All @@ -57,7 +57,7 @@ async def handle_tcp_to_websocket(tcp_reader, ws):
await ws.send(data)
except websockets.exceptions.ConnectionClosedError as e:
log.debug(f"connection closed: {e}")
except websockets.exceptions.ConnectionClosedOk:
except websockets.exceptions.ConnectionClosedOK:
log.debug("connection closed")
except Exception as e:
log.error(e)
Expand All @@ -70,21 +70,17 @@ async def handle_websocket_to_tcp(ws, tcp_writer):
await tcp_writer.drain()
except websockets.exceptions.ConnectionClosedError as e:
log.debug(f"connection closed: {e}")
except websockets.exceptions.ConnectionClosedOk:
except websockets.exceptions.ConnectionClosedOK:
log.debug("connection closed")
except Exception as e:
log.error(e)


async def handle_client(tcp_reader, tcp_writer, ws_url: str, ws_host: str, ws_port: int):
async def handle_client(tcp_reader, tcp_writer, ws_url: str, ctx: ssl.SSLContext):
try:
log.info("Received TCP connection. Connecting to websockets proxy.")

ctx = ssl.create_default_context(Purpose.SERVER_AUTH)
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE

async with websockets.connect(ws_url, ssl=ctx, host=ws_host, port=ws_port) as ws:
async with websockets.connect(ws_url, ssl=ctx) as ws:
try:
log.info("Connected to websockets proxy")

Expand Down Expand Up @@ -143,9 +139,11 @@ async def main():
if args.verbose:
enable_verbose_logging()

server = await start_server(
args.tcp_listen_addr, args.tcp_listen_port, args.ws_url, args.ws_host, args.ws_port
)
ctx = ssl.create_default_context(Purpose.SERVER_AUTH)
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE

server = await start_server(args.tcp_listen_addr, args.tcp_listen_port, args.ws_url, ctx)
print(
f"Listening for connections at {args.tcp_listen_addr}:{args.tcp_listen_port}, forwarding them to {args.ws_host}:{args.ws_port}"
)
Expand Down

0 comments on commit 7f68453

Please sign in to comment.