Skip to content

Commit

Permalink
feat: Use tokio-tungstenite-wasm for client ws connecting
Browse files Browse the repository at this point in the history
  • Loading branch information
matheus23 committed Jun 20, 2024
1 parent 60b26c4 commit 36b8235
Show file tree
Hide file tree
Showing 6 changed files with 272 additions and 157 deletions.
56 changes: 53 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions iroh-net/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ strum = { version = "0.26.2", features = ["derive"] }
tungstenite = "0.23.0"
fastwebsockets = { git = "https://github.com/denoland/fastwebsockets", revision = "efc0788", features = ["upgrade", "unstable-split"] }
tokio-tungstenite = "0.23.1"
tokio-tungstenite-wasm = "0.3.1"

[target.'cfg(any(target_os = "linux", target_os = "android"))'.dependencies]
netlink-packet-core = "0.7.0"
Expand Down
153 changes: 91 additions & 62 deletions iroh-net/src/relay/client.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
//! based on tailscale/derp/derp_client.go
use std::marker::PhantomData;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;

use anyhow::{anyhow, bail, ensure, Result};
use bytes::{Bytes, BytesMut};
use fastwebsockets::{WebSocketError, WebSocketRead, WebSocketWrite};
use futures_lite::StreamExt;
use futures_sink::Sink;
use futures_util::sink::SinkExt;
use futures_util::TryFutureExt;
use tokio::io::AsyncWrite;
use futures_util::stream::{SplitSink, SplitStream};
use futures_util::{Stream, StreamExt, TryFutureExt};
use tokio::io::{AsyncWrite, ReadHalf, WriteHalf};
use tokio::sync::mpsc;
use tokio_tungstenite_wasm::{Message, WebSocketStream};
use tokio_util::codec::{Decoder, Encoder, FramedRead, FramedWrite};
use tracing::{debug, info_span, trace, Instrument};

Expand Down Expand Up @@ -67,12 +71,10 @@ impl ClientReceiver {
}
}

type RelayReader = FramedRead<MaybeTlsStreamReader, DerpCodec>;

#[derive(derive_more::Debug)]
pub struct InnerClient {
// our local addrs
local_addr: SocketAddr,
local_addr: Option<SocketAddr>,
/// Channel on which to communicate to the server. The associated [`mpsc::Receiver`] will close
/// if there is ever an error writing to the server.
writer_channel: mpsc::Sender<ClientWriterMessage>,
Expand Down Expand Up @@ -126,8 +128,8 @@ impl Client {
}

/// The local address that the [`Client`] is listening on.
pub fn local_addr(&self) -> Result<SocketAddr> {
Ok(self.inner.local_addr)
pub fn local_addr(&self) -> Option<SocketAddr> {
self.inner.local_addr
}

/// Whether or not this [`Client`] is closed.
Expand Down Expand Up @@ -205,44 +207,42 @@ enum ClientWriterMessage {
NotePreferred(bool),
/// Shutdown the writer
Shutdown,
/// Send arbitrary websocket frames
SendWsFrame(#[debug("fastwebsockets::Frame")] fastwebsockets::Frame<'static>),
}

/// Call [`ClientWriter::run`] to listen for messages to send to the client.
/// Should be used by the [`Client`]
///
/// Shutsdown when you send a [`ClientWriterMessage::Shutdown`], or if there is an error writing to
/// the server.
struct ClientWriter<W: AsyncWrite + Unpin + Send + 'static> {
struct ClientWriter<W: Sink<Frame, Error = std::io::Error> + Unpin + 'static> {
recv_msgs: mpsc::Receiver<ClientWriterMessage>,
writer: WebSocketWrite<W>,
writer: W,
rate_limiter: Option<RateLimiter>,
}

impl<W: AsyncWrite + Unpin + Send + 'static> ClientWriter<W> {
impl<W: Sink<Frame, Error = std::io::Error> + Unpin + 'static> ClientWriter<W> {
async fn run(mut self) -> Result<()> {
while let Some(msg) = self.recv_msgs.recv().await {
match msg {
ClientWriterMessage::Packet((key, bytes)) => {
send_packet_ws(&mut self.writer, &self.rate_limiter, key, bytes).await?;
send_packet(&mut self.writer, &self.rate_limiter, key, bytes).await?;
self.writer.flush();
}
ClientWriterMessage::Pong(data) => {
write_frame_ws(&mut self.writer, Frame::Pong { data }, None).await?;
write_frame(&mut self.writer, Frame::Pong { data }, None).await?;
self.writer.flush();
}
ClientWriterMessage::Ping(data) => {
write_frame_ws(&mut self.writer, Frame::Ping { data }, None).await?;
write_frame(&mut self.writer, Frame::Ping { data }, None).await?;
self.writer.flush();
}
ClientWriterMessage::NotePreferred(preferred) => {
write_frame_ws(&mut self.writer, Frame::NotePreferred { preferred }, None)
.await?;
write_frame(&mut self.writer, Frame::NotePreferred { preferred }, None).await?;
self.writer.flush();
}
ClientWriterMessage::Shutdown => {
return Ok(());
}
ClientWriterMessage::SendWsFrame(frame) => {
self.writer.write_frame(frame).await?;
}
}
}

Expand All @@ -253,17 +253,74 @@ impl<W: AsyncWrite + Unpin + Send + 'static> ClientWriter<W> {
/// The Builder returns a [`Client`] starts a [`ClientWriter`] run task.
pub struct ClientBuilder {
secret_key: SecretKey,
reader: WebSocketRead<MaybeTlsStreamReader>,
writer: WebSocketWrite<MaybeTlsStreamWriter>,
local_addr: SocketAddr,
reader: RelayConnReader,
writer: RelayConnWriter,
local_addr: Option<SocketAddr>,
}

#[derive(derive_more::Debug)]
pub(crate) enum RelayConnReader {
Ws(#[debug("SplitStream<WebSocketStream>")] SplitStream<WebSocketStream>),
}

#[derive(derive_more::Debug)]
pub(crate) enum RelayConnWriter {
Ws(
#[debug("SplitSink<WebSocketStream, tokio_tungstenite_wasm::Message>")]
SplitSink<WebSocketStream, tokio_tungstenite_wasm::Message>,
),
}

fn tung_to_io_err(e: tokio_tungstenite_wasm::Error) -> std::io::Error {
std::io::Error::new(std::io::ErrorKind::Other, e.to_string())
}

impl Stream for RelayConnReader {
type Item = Result<Frame>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match *self {
Self::Ws(ref mut ws) => ws.poll_next_unpin(cx).map(Frame::from_wasm_ws_message),
}
}
}

impl Sink<Frame> for RelayConnWriter {
type Error = std::io::Error;

fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match *self {
Self::Ws(ref mut ws) => ws.poll_ready_unpin(cx).map_err(tung_to_io_err),
}
}

fn start_send(mut self: Pin<&mut Self>, item: Frame) -> Result<(), Self::Error> {
match *self {
Self::Ws(ref mut ws) => ws
.start_send_unpin(item.into_wasm_ws_message()?)
.map_err(tung_to_io_err),
}
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match *self {
Self::Ws(ref mut ws) => ws.poll_flush_unpin(cx).map_err(tung_to_io_err),
}
}

fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match *self {
Self::Ws(ref mut ws) => ws.poll_close_unpin(cx).map_err(tung_to_io_err),
}
}
}

impl ClientBuilder {
pub fn new(
secret_key: SecretKey,
local_addr: SocketAddr,
reader: WebSocketRead<MaybeTlsStreamReader>,
writer: WebSocketWrite<MaybeTlsStreamWriter>,
local_addr: Option<SocketAddr>,
reader: RelayConnReader,
writer: RelayConnWriter,
) -> Self {
Self {
secret_key,
Expand All @@ -279,7 +336,7 @@ impl ClientBuilder {
version: PROTOCOL_VERSION,
};
debug!("server_handshake: sending client_key: {:?}", &client_info);
crate::relay::codec::send_client_key_ws(&mut self.writer, &self.secret_key, &client_info)
crate::relay::codec::send_client_key(&mut self.writer, &self.secret_key, &client_info)
.await?;

// TODO: add some actual configuration
Expand Down Expand Up @@ -311,46 +368,18 @@ impl ClientBuilder {
let (reader_sender, reader_recv) = mpsc::channel(PER_CLIENT_READ_QUEUE_DEPTH);
let writer_sender2 = writer_sender.clone();
let reader_task = tokio::task::spawn(async move {
let mut send_fn = |mut frame: fastwebsockets::Frame<'_>| {
frame.unmask();
let frame: fastwebsockets::Frame<'static> = fastwebsockets::Frame::new(
frame.fin,
frame.opcode,
None,
fastwebsockets::Payload::Owned(frame.payload.to_vec()),
);
writer_sender2
.send(ClientWriterMessage::SendWsFrame(frame))
.map_err(|e| WebSocketError::SendError(e.into()))
};
loop {
let frame =
tokio::time::timeout(CLIENT_RECV_TIMEOUT, self.reader.read_frame(&mut send_fn))
.await;
let frame = tokio::time::timeout(CLIENT_RECV_TIMEOUT, self.reader.next()).await;
let res = match frame {
Ok(Ok(frame)) => {
// TODO(matheus23): Handle partial frames :S
if frame.opcode != fastwebsockets::OpCode::Binary {
tracing::warn!(?frame.opcode, "Ignoring frame with opcode != Binary");
continue;
} else {
let mut bytes = BytesMut::new();
bytes.extend_from_slice(frame.payload.as_ref()); // TODO(matheus23): Slow for now
if let Ok(Some(frame)) = DerpCodec.decode(&mut bytes) {
process_incoming_frame(frame)
} else {
Err(anyhow::anyhow!("Failed to read frame"))
}
}
Ok(Some(Ok(frame))) => process_incoming_frame(frame),
Ok(Some(Err(err))) => {
// Error processing incoming messages
Err(err)
}
Ok(Err(WebSocketError::ConnectionClosed)) => {
Ok(None) => {
// EOF
Err(anyhow::anyhow!("EOF: reader stream ended"))
}
Ok(Err(err)) => {
// Error processing incoming messages
Err(err.into())
}
Err(err) => {
// Timeout
Err(err.into())
Expand Down
Loading

0 comments on commit 36b8235

Please sign in to comment.