diff --git a/Cargo.lock b/Cargo.lock index 1f5f1790d67..0444500570c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2902,12 +2902,13 @@ dependencies = [ "tokio", "tokio-rustls 0.24.1", "tokio-rustls-acme", - "tokio-tungstenite", + "tokio-tungstenite 0.23.1", + "tokio-tungstenite-wasm", "tokio-util", "toml", "tracing", "tracing-subscriber", - "tungstenite", + "tungstenite 0.23.0", "url", "watchable", "webpki-roots 0.25.4", @@ -5890,6 +5891,18 @@ dependencies = [ "tokio-util", ] +[[package]] +name = "tokio-tungstenite" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite 0.21.0", +] + [[package]] name = "tokio-tungstenite" version = "0.23.1" @@ -5899,7 +5912,25 @@ dependencies = [ "futures-util", "log", "tokio", - "tungstenite", + "tungstenite 0.23.0", +] + +[[package]] +name = "tokio-tungstenite-wasm" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e57a65894797a018b28345fa298a00c450a574aa9671e50b18218a6292a55ac" +dependencies = [ + "futures-channel", + "futures-util", + "http 1.1.0", + "httparse", + "js-sys", + "thiserror", + "tokio", + "tokio-tungstenite 0.21.0", + "wasm-bindgen", + "web-sys", ] [[package]] @@ -6125,6 +6156,25 @@ dependencies = [ "linked-hash-map", ] +[[package]] +name = "tungstenite" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1" +dependencies = [ + "byteorder", + "bytes", + "data-encoding", + "http 1.1.0", + "httparse", + "log", + "rand", + "sha1", + "thiserror", + "url", + "utf-8", +] + [[package]] name = "tungstenite" version = "0.23.0" diff --git a/iroh-net/Cargo.toml b/iroh-net/Cargo.toml index f54f58af850..52220a372de 100644 --- a/iroh-net/Cargo.toml +++ b/iroh-net/Cargo.toml @@ -91,6 +91,7 @@ strum = { version = "0.26.2", features = ["derive"] } tungstenite = "0.23.0" fastwebsockets = { git = "https://github.com/denoland/fastwebsockets", revision = "efc0788", features = ["upgrade", "unstable-split"] } tokio-tungstenite = "0.23.1" +tokio-tungstenite-wasm = "0.3.1" [target.'cfg(any(target_os = "linux", target_os = "android"))'.dependencies] netlink-packet-core = "0.7.0" diff --git a/iroh-net/src/relay/client.rs b/iroh-net/src/relay/client.rs index 158217f163c..13290d10449 100644 --- a/iroh-net/src/relay/client.rs +++ b/iroh-net/src/relay/client.rs @@ -1,17 +1,21 @@ //! based on tailscale/derp/derp_client.go +use std::marker::PhantomData; use std::net::SocketAddr; +use std::pin::Pin; use std::sync::Arc; +use std::task::{Context, Poll}; use std::time::Duration; use anyhow::{anyhow, bail, ensure, Result}; use bytes::{Bytes, BytesMut}; use fastwebsockets::{WebSocketError, WebSocketRead, WebSocketWrite}; -use futures_lite::StreamExt; use futures_sink::Sink; use futures_util::sink::SinkExt; -use futures_util::TryFutureExt; -use tokio::io::AsyncWrite; +use futures_util::stream::{SplitSink, SplitStream}; +use futures_util::{Stream, StreamExt, TryFutureExt}; +use tokio::io::{AsyncWrite, ReadHalf, WriteHalf}; use tokio::sync::mpsc; +use tokio_tungstenite_wasm::{Message, WebSocketStream}; use tokio_util::codec::{Decoder, Encoder, FramedRead, FramedWrite}; use tracing::{debug, info_span, trace, Instrument}; @@ -67,12 +71,10 @@ impl ClientReceiver { } } -type RelayReader = FramedRead; - #[derive(derive_more::Debug)] pub struct InnerClient { // our local addrs - local_addr: SocketAddr, + local_addr: Option, /// Channel on which to communicate to the server. The associated [`mpsc::Receiver`] will close /// if there is ever an error writing to the server. writer_channel: mpsc::Sender, @@ -126,8 +128,8 @@ impl Client { } /// The local address that the [`Client`] is listening on. - pub fn local_addr(&self) -> Result { - Ok(self.inner.local_addr) + pub fn local_addr(&self) -> Option { + self.inner.local_addr } /// Whether or not this [`Client`] is closed. @@ -205,8 +207,6 @@ enum ClientWriterMessage { NotePreferred(bool), /// Shutdown the writer Shutdown, - /// Send arbitrary websocket frames - SendWsFrame(#[debug("fastwebsockets::Frame")] fastwebsockets::Frame<'static>), } /// Call [`ClientWriter::run`] to listen for messages to send to the client. @@ -214,35 +214,35 @@ enum ClientWriterMessage { /// /// Shutsdown when you send a [`ClientWriterMessage::Shutdown`], or if there is an error writing to /// the server. -struct ClientWriter { +struct ClientWriter + Unpin + 'static> { recv_msgs: mpsc::Receiver, - writer: WebSocketWrite, + writer: W, rate_limiter: Option, } -impl ClientWriter { +impl + Unpin + 'static> ClientWriter { async fn run(mut self) -> Result<()> { while let Some(msg) = self.recv_msgs.recv().await { match msg { ClientWriterMessage::Packet((key, bytes)) => { - send_packet_ws(&mut self.writer, &self.rate_limiter, key, bytes).await?; + send_packet(&mut self.writer, &self.rate_limiter, key, bytes).await?; + self.writer.flush(); } ClientWriterMessage::Pong(data) => { - write_frame_ws(&mut self.writer, Frame::Pong { data }, None).await?; + write_frame(&mut self.writer, Frame::Pong { data }, None).await?; + self.writer.flush(); } ClientWriterMessage::Ping(data) => { - write_frame_ws(&mut self.writer, Frame::Ping { data }, None).await?; + write_frame(&mut self.writer, Frame::Ping { data }, None).await?; + self.writer.flush(); } ClientWriterMessage::NotePreferred(preferred) => { - write_frame_ws(&mut self.writer, Frame::NotePreferred { preferred }, None) - .await?; + write_frame(&mut self.writer, Frame::NotePreferred { preferred }, None).await?; + self.writer.flush(); } ClientWriterMessage::Shutdown => { return Ok(()); } - ClientWriterMessage::SendWsFrame(frame) => { - self.writer.write_frame(frame).await?; - } } } @@ -253,17 +253,74 @@ impl ClientWriter { /// The Builder returns a [`Client`] starts a [`ClientWriter`] run task. pub struct ClientBuilder { secret_key: SecretKey, - reader: WebSocketRead, - writer: WebSocketWrite, - local_addr: SocketAddr, + reader: RelayConnReader, + writer: RelayConnWriter, + local_addr: Option, +} + +#[derive(derive_more::Debug)] +pub(crate) enum RelayConnReader { + Ws(#[debug("SplitStream")] SplitStream), +} + +#[derive(derive_more::Debug)] +pub(crate) enum RelayConnWriter { + Ws( + #[debug("SplitSink")] + SplitSink, + ), +} + +fn tung_to_io_err(e: tokio_tungstenite_wasm::Error) -> std::io::Error { + std::io::Error::new(std::io::ErrorKind::Other, e.to_string()) +} + +impl Stream for RelayConnReader { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match *self { + Self::Ws(ref mut ws) => ws.poll_next_unpin(cx).map(Frame::from_wasm_ws_message), + } + } +} + +impl Sink for RelayConnWriter { + type Error = std::io::Error; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match *self { + Self::Ws(ref mut ws) => ws.poll_ready_unpin(cx).map_err(tung_to_io_err), + } + } + + fn start_send(mut self: Pin<&mut Self>, item: Frame) -> Result<(), Self::Error> { + match *self { + Self::Ws(ref mut ws) => ws + .start_send_unpin(item.into_wasm_ws_message()?) + .map_err(tung_to_io_err), + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match *self { + Self::Ws(ref mut ws) => ws.poll_flush_unpin(cx).map_err(tung_to_io_err), + } + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match *self { + Self::Ws(ref mut ws) => ws.poll_close_unpin(cx).map_err(tung_to_io_err), + } + } } impl ClientBuilder { pub fn new( secret_key: SecretKey, - local_addr: SocketAddr, - reader: WebSocketRead, - writer: WebSocketWrite, + local_addr: Option, + reader: RelayConnReader, + writer: RelayConnWriter, ) -> Self { Self { secret_key, @@ -279,7 +336,7 @@ impl ClientBuilder { version: PROTOCOL_VERSION, }; debug!("server_handshake: sending client_key: {:?}", &client_info); - crate::relay::codec::send_client_key_ws(&mut self.writer, &self.secret_key, &client_info) + crate::relay::codec::send_client_key(&mut self.writer, &self.secret_key, &client_info) .await?; // TODO: add some actual configuration @@ -311,46 +368,18 @@ impl ClientBuilder { let (reader_sender, reader_recv) = mpsc::channel(PER_CLIENT_READ_QUEUE_DEPTH); let writer_sender2 = writer_sender.clone(); let reader_task = tokio::task::spawn(async move { - let mut send_fn = |mut frame: fastwebsockets::Frame<'_>| { - frame.unmask(); - let frame: fastwebsockets::Frame<'static> = fastwebsockets::Frame::new( - frame.fin, - frame.opcode, - None, - fastwebsockets::Payload::Owned(frame.payload.to_vec()), - ); - writer_sender2 - .send(ClientWriterMessage::SendWsFrame(frame)) - .map_err(|e| WebSocketError::SendError(e.into())) - }; loop { - let frame = - tokio::time::timeout(CLIENT_RECV_TIMEOUT, self.reader.read_frame(&mut send_fn)) - .await; + let frame = tokio::time::timeout(CLIENT_RECV_TIMEOUT, self.reader.next()).await; let res = match frame { - Ok(Ok(frame)) => { - // TODO(matheus23): Handle partial frames :S - if frame.opcode != fastwebsockets::OpCode::Binary { - tracing::warn!(?frame.opcode, "Ignoring frame with opcode != Binary"); - continue; - } else { - let mut bytes = BytesMut::new(); - bytes.extend_from_slice(frame.payload.as_ref()); // TODO(matheus23): Slow for now - if let Ok(Some(frame)) = DerpCodec.decode(&mut bytes) { - process_incoming_frame(frame) - } else { - Err(anyhow::anyhow!("Failed to read frame")) - } - } + Ok(Some(Ok(frame))) => process_incoming_frame(frame), + Ok(Some(Err(err))) => { + // Error processing incoming messages + Err(err) } - Ok(Err(WebSocketError::ConnectionClosed)) => { + Ok(None) => { // EOF Err(anyhow::anyhow!("EOF: reader stream ended")) } - Ok(Err(err)) => { - // Error processing incoming messages - Err(err.into()) - } Err(err) => { // Timeout Err(err.into()) diff --git a/iroh-net/src/relay/codec.rs b/iroh-net/src/relay/codec.rs index 80f39e46040..d34a8abe7a2 100644 --- a/iroh-net/src/relay/codec.rs +++ b/iroh-net/src/relay/codec.rs @@ -315,6 +315,58 @@ impl Frame { } } + pub fn into_ws_message(self) -> std::io::Result { + let mut bytes = bytes::BytesMut::new(); + DerpCodec.encode(self, &mut bytes)?; + Ok(tungstenite::Message::binary(bytes)) + } + + pub fn into_wasm_ws_message(self) -> std::io::Result { + let mut bytes = bytes::BytesMut::new(); + DerpCodec.encode(self, &mut bytes)?; + Ok(tokio_tungstenite_wasm::Message::binary(bytes)) + } + + pub fn from_ws_message( + msg: Option>, + ) -> Option> { + match msg { + Some(Ok(tungstenite::Message::Binary(vec))) => { + let mut bytes = BytesMut::new(); + bytes.extend_from_slice(&vec); // TODO(matheus23) this is slow/weird + Some(DerpCodec.decode(&mut bytes).and_then(|option| { + option.ok_or_else(|| anyhow::anyhow!("incomplete frame in websocket message")) + })) + } + Some(Ok(msg)) => { + tracing::warn!(?msg, "Got msg of unsupported type, skipping."); + None + } + Some(Err(e)) => Some(Err(e.into())), + None => None, + } + } + + pub fn from_wasm_ws_message( + msg: Option>, + ) -> Option> { + match msg { + Some(Ok(tokio_tungstenite_wasm::Message::Binary(vec))) => { + let mut bytes = BytesMut::new(); + bytes.extend_from_slice(&vec); // TODO(matheus23) this is slow/weird + Some(DerpCodec.decode(&mut bytes).and_then(|option| { + option.ok_or_else(|| anyhow::anyhow!("incomplete frame in websocket message")) + })) + } + Some(Ok(msg)) => { + tracing::warn!(?msg, "Got msg of unsupported type, skipping."); + None + } + Some(Err(e)) => Some(Err(e.into())), + None => None, + } + } + /// Writes it self to the given buffer. fn write_to(&self, dst: &mut BytesMut) { match self { diff --git a/iroh-net/src/relay/http/client.rs b/iroh-net/src/relay/http/client.rs index 592c1d6e6f0..4baa0c7fbe7 100644 --- a/iroh-net/src/relay/http/client.rs +++ b/iroh-net/src/relay/http/client.rs @@ -9,6 +9,7 @@ use base64::{engine::general_purpose::URL_SAFE, Engine as _}; use bytes::Bytes; use fastwebsockets::WebSocket; use futures_lite::future::Boxed as BoxFuture; +use futures_util::StreamExt; use http_body_util::Empty; use hyper::body::Incoming; use hyper::header::UPGRADE; @@ -22,11 +23,13 @@ use tokio::net::TcpStream; use tokio::sync::{mpsc, oneshot}; use tokio::task::JoinSet; use tokio::time::Instant; +use tokio_tungstenite_wasm::WebSocketStream; use tracing::{debug, error, info_span, trace, warn, Instrument}; use url::Url; use crate::dns::{DnsResolver, ResolverExt}; use crate::key::{PublicKey, SecretKey}; +use crate::relay::client::{RelayConnReader, RelayConnWriter}; use crate::relay::http::streams::{downcast_upgrade, MaybeTlsStream}; use crate::relay::http::WEBSOCKET_UPGRADE_PROTOCOL; use crate::relay::RelayUrl; @@ -122,6 +125,9 @@ pub enum ClientError { /// The inner actor is gone, likely means things are shutdown. #[error("actor gone")] ActorGone, + /// There was an error related to websockets + #[error("websocket error")] + WebsocketError(#[from] tokio_tungstenite_wasm::Error), } /// An HTTP Relay client. @@ -582,53 +588,62 @@ impl Actor { } async fn connect_0(&self) -> Result<(RelayClient, RelayClientReceiver), ClientError> { - let tcp_stream = self.dial_url().await?; - - let local_addr = tcp_stream - .local_addr() - .map_err(|e| ClientError::NoLocalAddr(e.to_string()))?; - - debug!(server_addr = ?tcp_stream.peer_addr(), %local_addr, "TCP stream connected"); - - let response = if self.use_https() { - debug!("Starting TLS handshake"); - let hostname = self - .tls_servername() - .ok_or_else(|| ClientError::InvalidUrl("No tls servername".into()))?; - let tls_stream = self.tls_connector.connect(hostname, tcp_stream).await?; - debug!("tls_connector connect success"); - Self::start_upgrade(tls_stream).await? - } else { - debug!("Starting handshake"); - Self::start_upgrade(tcp_stream).await? - }; - - if response.status() != hyper::StatusCode::SWITCHING_PROTOCOLS { - error!( - "expected status 101 SWITCHING_PROTOCOLS, got: {}", - response.status() - ); - return Err(ClientError::UnexpectedStatusCode( - hyper::StatusCode::SWITCHING_PROTOCOLS, - response.status(), - )); - } - - debug!("starting upgrade"); - let upgraded = match hyper::upgrade::on(response).await { - Ok(upgraded) => upgraded, - Err(err) => { - warn!("upgrade failed: {:#}", err); - return Err(ClientError::Hyper(err)); - } - }; - - debug!("connection upgraded"); - let (reader, writer) = - downcast_upgrade(upgraded).map_err(|e| ClientError::Upgrade(e.to_string()))?; - - let (reader, writer) = - fastwebsockets::after_handshake_split(reader, writer, fastwebsockets::Role::Client); + // let tcp_stream = self.dial_url().await?; + + // let local_addr = tcp_stream + // .local_addr() + // .map_err(|e| ClientError::NoLocalAddr(e.to_string()))?; + + // debug!(server_addr = ?tcp_stream.peer_addr(), %local_addr, "TCP stream connected"); + + // let response = if self.use_https() { + // debug!("Starting TLS handshake"); + // let hostname = self + // .tls_servername() + // .ok_or_else(|| ClientError::InvalidUrl("No tls servername".into()))?; + // let tls_stream = self.tls_connector.connect(hostname, tcp_stream).await?; + // debug!("tls_connector connect success"); + // Self::start_upgrade(tls_stream).await? + // } else { + // debug!("Starting handshake"); + // Self::start_upgrade(tcp_stream).await? + // }; + + // if response.status() != hyper::StatusCode::SWITCHING_PROTOCOLS { + // error!( + // "expected status 101 SWITCHING_PROTOCOLS, got: {}", + // response.status() + // ); + // return Err(ClientError::UnexpectedStatusCode( + // hyper::StatusCode::SWITCHING_PROTOCOLS, + // response.status(), + // )); + // } + + // debug!("starting upgrade"); + // let upgraded = match hyper::upgrade::on(response).await { + // Ok(upgraded) => upgraded, + // Err(err) => { + // warn!("upgrade failed: {:#}", err); + // return Err(ClientError::Hyper(err)); + // } + // }; + + // debug!("connection upgraded"); + // let (reader, writer) = + // downcast_upgrade(upgraded).map_err(|e| ClientError::Upgrade(e.to_string()))?; + + // let (reader, writer) = + // fastwebsockets::after_handshake_split(reader, writer, fastwebsockets::Role::Client); + + let (writer, reader) = tokio_tungstenite_wasm::connect(self.url.as_str()) + .await? + .split(); + + let reader = RelayConnReader::Ws(reader); + let writer = RelayConnWriter::Ws(writer); + + let local_addr = None; let (relay_client, receiver) = RelayClientBuilder::new(self.secret_key.clone(), local_addr, reader, writer) @@ -705,12 +720,10 @@ impl Actor { return None; } if let Some((ref client, _)) = self.relay_client { - match client.local_addr() { - Ok(addr) => return Some(addr), - _ => return None, - } + client.local_addr() + } else { + None } - None } async fn ping(&mut self, s: oneshot::Sender>) { diff --git a/iroh-net/src/relay/server.rs b/iroh-net/src/relay/server.rs index baccf7adb41..48604f49fcc 100644 --- a/iroh-net/src/relay/server.rs +++ b/iroh-net/src/relay/server.rs @@ -185,34 +185,8 @@ pub(crate) enum RelayIo { Ws(WebSocketStream), } -impl RelayIo { - fn frame_to_message(item: Frame) -> std::io::Result { - let mut bytes = bytes::BytesMut::new(); - DerpCodec.encode(item, &mut bytes)?; - Ok(Message::binary(bytes)) - } - - fn message_to_frame(msg: Option>) -> Option> { - match msg { - Some(Ok(Message::Binary(vec))) => { - let mut bytes = BytesMut::new(); - bytes.extend_from_slice(&vec); // TODO(matheus23) this is slow/weird - Some(DerpCodec.decode(&mut bytes).and_then(|option| { - option.ok_or_else(|| anyhow::anyhow!("incomplete frame in websocket message")) - })) - } - Some(Ok(msg)) => { - tracing::warn!(?msg, "Got msg of unsupported type, skipping."); - None - } - Some(Err(e)) => Some(Err(e.into())), - None => None, - } - } - - fn tung_to_io_err(e: tungstenite::Error) -> std::io::Error { - std::io::Error::new(std::io::ErrorKind::Other, e.to_string()) - } +fn tung_to_io_err(e: tungstenite::Error) -> std::io::Error { + std::io::Error::new(std::io::ErrorKind::Other, e.to_string()) } impl Sink for RelayIo { @@ -221,7 +195,7 @@ impl Sink for RelayIo { fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match *self { Self::Derp(ref mut framed) => framed.poll_ready_unpin(cx), - Self::Ws(ref mut ws) => ws.poll_ready_unpin(cx).map_err(Self::tung_to_io_err), + Self::Ws(ref mut ws) => ws.poll_ready_unpin(cx).map_err(tung_to_io_err), } } @@ -229,22 +203,22 @@ impl Sink for RelayIo { match *self { Self::Derp(ref mut framed) => framed.start_send_unpin(item), Self::Ws(ref mut ws) => ws - .start_send_unpin(Self::frame_to_message(item)?) - .map_err(Self::tung_to_io_err), + .start_send_unpin(item.into_ws_message()?) + .map_err(tung_to_io_err), } } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match *self { Self::Derp(ref mut framed) => framed.poll_flush_unpin(cx), - Self::Ws(ref mut ws) => ws.poll_flush_unpin(cx).map_err(Self::tung_to_io_err), + Self::Ws(ref mut ws) => ws.poll_flush_unpin(cx).map_err(tung_to_io_err), } } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match *self { Self::Derp(ref mut framed) => framed.poll_close_unpin(cx), - Self::Ws(ref mut ws) => ws.poll_close_unpin(cx).map_err(Self::tung_to_io_err), + Self::Ws(ref mut ws) => ws.poll_close_unpin(cx).map_err(tung_to_io_err), } } } @@ -255,7 +229,7 @@ impl Stream for RelayIo { fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match *self { Self::Derp(ref mut framed) => framed.poll_next_unpin(cx), - Self::Ws(ref mut ws) => ws.poll_next_unpin(cx).map(Self::message_to_frame), + Self::Ws(ref mut ws) => ws.poll_next_unpin(cx).map(Frame::from_ws_message), } } } @@ -686,14 +660,10 @@ mod tests { MaybeTlsStreamWriter::Mem(client_writer), fastwebsockets::Role::Client, ); + let (client_reader, client_writer) = todo!(); // TODO(matheus23) fix tests. Probably just use relay protocol here ( server, - ClientBuilder::new( - secret_key, - "127.0.0.1:0".parse().unwrap(), - client_reader, - client_writer, - ), + ClientBuilder::new(secret_key, None, client_reader, client_writer), ) }