diff --git a/Cargo.lock b/Cargo.lock index 8de7a135b5..c8cf5e2442 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2832,10 +2832,13 @@ dependencies = [ "tokio", "tokio-rustls 0.24.1", "tokio-rustls-acme", + "tokio-tungstenite", + "tokio-tungstenite-wasm", "tokio-util", "toml", "tracing", "tracing-subscriber", + "tungstenite", "url", "watchable", "webpki-roots 0.25.4", @@ -5737,6 +5740,36 @@ dependencies = [ "tokio-util", ] +[[package]] +name = "tokio-tungstenite" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite", +] + +[[package]] +name = "tokio-tungstenite-wasm" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e57a65894797a018b28345fa298a00c450a574aa9671e50b18218a6292a55ac" +dependencies = [ + "futures-channel", + "futures-util", + "http 1.1.0", + "httparse", + "js-sys", + "thiserror", + "tokio", + "tokio-tungstenite", + "wasm-bindgen", + "web-sys", +] + [[package]] name = "tokio-util" version = "0.7.11" @@ -5960,6 +5993,25 @@ dependencies = [ "linked-hash-map", ] +[[package]] +name = "tungstenite" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1" +dependencies = [ + "byteorder", + "bytes", + "data-encoding", + "http 1.1.0", + "httparse", + "log", + "rand", + "sha1", + "thiserror", + "url", + "utf-8", +] + [[package]] name = "typenum" version = "1.17.0" @@ -6070,6 +6122,12 @@ dependencies = [ "serde", ] +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utf8parse" version = "0.2.1" diff --git a/iroh-blobs/src/downloader/test.rs b/iroh-blobs/src/downloader/test.rs index 9901cdf2e4..ec54e0ef8c 100644 --- a/iroh-blobs/src/downloader/test.rs +++ b/iroh-blobs/src/downloader/test.rs @@ -10,7 +10,7 @@ use iroh_net::key::SecretKey; use crate::{ get::{db::BlobId, progress::TransferState}, - util::progress::{FlumeProgressSender, IdGenerator, ProgressSender}, + util::progress::{FlumeProgressSender, IdGenerator}, }; use super::*; diff --git a/iroh-blobs/src/downloader/test/dialer.rs b/iroh-blobs/src/downloader/test/dialer.rs index 89a1af69b2..4d087145fb 100644 --- a/iroh-blobs/src/downloader/test/dialer.rs +++ b/iroh-blobs/src/downloader/test/dialer.rs @@ -1,9 +1,6 @@ //! Implementation of [`super::Dialer`] used for testing. -use std::{ - collections::HashSet, - task::{Context, Poll}, -}; +use std::task::{Context, Poll}; use parking_lot::RwLock; diff --git a/iroh-blobs/src/downloader/test/getter.rs b/iroh-blobs/src/downloader/test/getter.rs index 378d26579e..397f1134f1 100644 --- a/iroh-blobs/src/downloader/test/getter.rs +++ b/iroh-blobs/src/downloader/test/getter.rs @@ -2,7 +2,6 @@ use futures_lite::{future::Boxed as BoxFuture, FutureExt}; use parking_lot::RwLock; -use std::{sync::Arc, time::Duration}; use super::*; diff --git a/iroh-blobs/src/store/bao_file.rs b/iroh-blobs/src/store/bao_file.rs index 1dc2f72a36..388b85ac60 100644 --- a/iroh-blobs/src/store/bao_file.rs +++ b/iroh-blobs/src/store/bao_file.rs @@ -732,7 +732,6 @@ pub mod test_support { BlockSize, ChunkRanges, }; use futures_lite::{Stream, StreamExt}; - use iroh_base::hash::Hash; use iroh_io::AsyncStreamReader; use rand::RngCore; use range_collections::RangeSet2; diff --git a/iroh-cli/src/commands/doctor.rs b/iroh-cli/src/commands/doctor.rs index a28f749cf6..926bfc52a6 100644 --- a/iroh-cli/src/commands/doctor.rs +++ b/iroh-cli/src/commands/doctor.rs @@ -44,7 +44,6 @@ use iroh::{ }; use portable_atomic::AtomicU64; use postcard::experimental::max_size::MaxSize; -use ratatui::backend::Backend; use serde::{Deserialize, Serialize}; use tokio::{io::AsyncWriteExt, sync}; diff --git a/iroh-net/Cargo.toml b/iroh-net/Cargo.toml index 6481dbba27..f6d636defc 100644 --- a/iroh-net/Cargo.toml +++ b/iroh-net/Cargo.toml @@ -68,8 +68,11 @@ time = "0.3.20" tokio = { version = "1", features = ["io-util", "macros", "sync", "rt", "net", "fs", "io-std", "signal", "process"] } tokio-rustls = { version = "0.24" } tokio-rustls-acme = { version = "0.3" } +tokio-tungstenite = "0.21" +tokio-tungstenite-wasm = "0.3" tokio-util = { version = "0.7", features = ["io-util", "io", "codec"] } tracing = "0.1" +tungstenite = "0.21" url = { version = "2.4", features = ["serde"] } watchable = "1.1.2" webpki = { package = "rustls-webpki", version = "0.101.4", features = ["std"] } diff --git a/iroh-net/src/magicsock.rs b/iroh-net/src/magicsock.rs index 5e4f1f0afe..d660070ffd 100644 --- a/iroh-net/src/magicsock.rs +++ b/iroh-net/src/magicsock.rs @@ -2747,7 +2747,6 @@ impl NetInfo { #[cfg(test)] mod tests { use anyhow::Context; - use futures_lite::StreamExt; use iroh_test::CallOnDrop; use rand::RngCore; diff --git a/iroh-net/src/relay.rs b/iroh-net/src/relay.rs index 88213f0635..746a8b607f 100644 --- a/iroh-net/src/relay.rs +++ b/iroh-net/src/relay.rs @@ -13,7 +13,7 @@ pub(crate) mod client; pub(crate) mod client_conn; pub(crate) mod clients; -mod codec; +pub(crate) mod codec; pub mod http; pub mod iroh_relay; mod map; diff --git a/iroh-net/src/relay/client.rs b/iroh-net/src/relay/client.rs index bf0e069bfc..173919d906 100644 --- a/iroh-net/src/relay/client.rs +++ b/iroh-net/src/relay/client.rs @@ -1,15 +1,18 @@ //! based on tailscale/derp/derp_client.go use std::net::SocketAddr; +use std::pin::Pin; use std::sync::Arc; +use std::task::{Context, Poll}; use std::time::Duration; use anyhow::{anyhow, bail, ensure, Result}; use bytes::Bytes; -use futures_lite::StreamExt; +use futures_lite::Stream; use futures_sink::Sink; -use futures_util::sink::SinkExt; -use tokio::io::AsyncWrite; +use futures_util::stream::{SplitSink, SplitStream, StreamExt}; +use futures_util::SinkExt; use tokio::sync::mpsc; +use tokio_tungstenite_wasm::WebSocketStream; use tokio_util::codec::{FramedRead, FramedWrite}; use tracing::{debug, info_span, trace, Instrument}; @@ -64,12 +67,12 @@ impl ClientReceiver { } } -type RelayReader = FramedRead; - #[derive(derive_more::Debug)] pub struct InnerClient { - // our local addrs - local_addr: SocketAddr, + /// Our local address, if known. + /// + /// Is `None` in tests or when using websockets (because we don't control connection establishment in browsers). + local_addr: Option, /// Channel on which to communicate to the server. The associated [`mpsc::Receiver`] will close /// if there is ever an error writing to the server. writer_channel: mpsc::Sender, @@ -123,8 +126,10 @@ impl Client { } /// The local address that the [`Client`] is listening on. - pub fn local_addr(&self) -> Result { - Ok(self.inner.local_addr) + /// + /// `None`, when run in a testing environment or when using websockets. + pub fn local_addr(&self) -> Option { + self.inner.local_addr } /// Whether or not this [`Client`] is closed. @@ -209,13 +214,13 @@ enum ClientWriterMessage { /// /// Shutsdown when you send a [`ClientWriterMessage::Shutdown`], or if there is an error writing to /// the server. -struct ClientWriter { +struct ClientWriter { recv_msgs: mpsc::Receiver, - writer: FramedWrite, + writer: ConnWriter, rate_limiter: Option, } -impl ClientWriter { +impl ClientWriter { async fn run(mut self) -> Result<()> { while let Some(msg) = self.recv_msgs.recv().await { match msg { @@ -244,25 +249,100 @@ impl ClientWriter { } } -/// The Builder returns a [`Client`] starts a [`ClientWriter`] run task. +/// The Builder returns a [`Client`] and a started [`ClientWriter`] run task. pub struct ClientBuilder { secret_key: SecretKey, - reader: RelayReader, - writer: FramedWrite, - local_addr: SocketAddr, + reader: ConnReader, + writer: ConnWriter, + local_addr: Option, +} + +pub(crate) enum ConnReader { + Derp(FramedRead), + Ws(SplitStream), +} + +pub(crate) enum ConnWriter { + Derp(FramedWrite), + Ws(SplitSink), +} + +fn tung_wasm_to_io_err(e: tokio_tungstenite_wasm::Error) -> std::io::Error { + match e { + tokio_tungstenite_wasm::Error::Io(io_err) => io_err, + _ => std::io::Error::new(std::io::ErrorKind::Other, e.to_string()), + } +} + +impl Stream for ConnReader { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match *self { + Self::Derp(ref mut ws) => Pin::new(ws).poll_next(cx), + Self::Ws(ref mut ws) => match Pin::new(ws).poll_next(cx) { + Poll::Ready(Some(Ok(tokio_tungstenite_wasm::Message::Binary(vec)))) => { + Poll::Ready(Some(Frame::decode_from_ws_msg(vec))) + } + Poll::Ready(Some(Ok(msg))) => { + tracing::warn!(?msg, "Got websocket message of unsupported type, skipping."); + Poll::Pending + } + Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e.into()))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + }, + } + } +} + +impl Sink for ConnWriter { + type Error = std::io::Error; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match *self { + Self::Derp(ref mut ws) => Pin::new(ws).poll_ready(cx), + Self::Ws(ref mut ws) => Pin::new(ws).poll_ready(cx).map_err(tung_wasm_to_io_err), + } + } + + fn start_send(mut self: Pin<&mut Self>, item: Frame) -> Result<(), Self::Error> { + match *self { + Self::Derp(ref mut ws) => Pin::new(ws).start_send(item), + Self::Ws(ref mut ws) => Pin::new(ws) + .start_send(tokio_tungstenite_wasm::Message::binary( + item.encode_for_ws_msg(), + )) + .map_err(tung_wasm_to_io_err), + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match *self { + Self::Derp(ref mut ws) => Pin::new(ws).poll_flush(cx), + Self::Ws(ref mut ws) => Pin::new(ws).poll_flush(cx).map_err(tung_wasm_to_io_err), + } + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match *self { + Self::Derp(ref mut ws) => Pin::new(ws).poll_close(cx), + Self::Ws(ref mut ws) => Pin::new(ws).poll_close(cx).map_err(tung_wasm_to_io_err), + } + } } impl ClientBuilder { pub fn new( secret_key: SecretKey, - local_addr: SocketAddr, - reader: MaybeTlsStreamReader, - writer: MaybeTlsStreamWriter, + local_addr: Option, + reader: ConnReader, + writer: ConnWriter, ) -> Self { Self { secret_key, - reader: FramedRead::new(reader, DerpCodec), - writer: FramedWrite::new(writer, DerpCodec), + reader, + writer, local_addr, } } diff --git a/iroh-net/src/relay/client_conn.rs b/iroh-net/src/relay/client_conn.rs index dd876c0ad9..05171937d1 100644 --- a/iroh-net/src/relay/client_conn.rs +++ b/iroh-net/src/relay/client_conn.rs @@ -7,7 +7,6 @@ use bytes::Bytes; use futures_lite::StreamExt; use futures_util::SinkExt; use tokio::sync::mpsc; -use tokio_util::codec::Framed; use tokio_util::sync::CancellationToken; use tracing::{trace, Instrument}; @@ -16,8 +15,8 @@ use crate::{disco::looks_like_disco_wrapper, key::PublicKey}; use iroh_metrics::{inc, inc_by}; -use super::codec::{DerpCodec, Frame}; -use super::server::MaybeTlsStream; +use super::codec::Frame; +use super::server::RelayIo; use super::{ codec::{write_frame, KEEP_ALIVE}, metrics::Metrics, @@ -73,7 +72,7 @@ pub(crate) struct ClientChannels { pub struct ClientConnBuilder { pub(crate) key: PublicKey, pub(crate) conn_num: usize, - pub(crate) io: Framed, + pub(crate) io: RelayIo, pub(crate) write_timeout: Option, pub(crate) channel_capacity: usize, pub(crate) server_channel: mpsc::Sender, @@ -102,7 +101,7 @@ impl ClientConnManager { pub fn new( key: PublicKey, conn_num: usize, - io: Framed, + io: RelayIo, write_timeout: Option, channel_capacity: usize, server_channel: mpsc::Sender, @@ -203,7 +202,7 @@ impl ClientConnManager { #[derive(Debug)] pub(crate) struct ClientConnIo { /// Io to talk to the client - io: Framed, + io: RelayIo, /// Max time we wait to complete a write to the client timeout: Option, /// Packets queued to send to the client @@ -453,11 +452,13 @@ impl ClientConnIo { #[cfg(test)] mod tests { use crate::key::SecretKey; - use crate::relay::codec::{recv_frame, FrameType}; + use crate::relay::codec::{recv_frame, DerpCodec, FrameType}; + use crate::relay::MaybeTlsStreamServer as MaybeTlsStream; use super::*; use anyhow::bail; + use tokio_util::codec::Framed; #[tokio::test] async fn test_client_conn_io_basic() -> Result<()> { @@ -472,7 +473,7 @@ mod tests { let (server_channel_s, mut server_channel_r) = mpsc::channel(10); let conn_io = ClientConnIo { - io: Framed::new(MaybeTlsStream::Test(io), DerpCodec), + io: RelayIo::Derp(Framed::new(MaybeTlsStream::Test(io), DerpCodec)), timeout: None, send_queue: send_queue_r, disco_send_queue: disco_send_queue_r, @@ -607,7 +608,7 @@ mod tests { println!("-- create client conn"); let conn_io = ClientConnIo { - io: Framed::new(MaybeTlsStream::Test(io), DerpCodec), + io: RelayIo::Derp(Framed::new(MaybeTlsStream::Test(io), DerpCodec)), timeout: None, send_queue: send_queue_r, disco_send_queue: disco_send_queue_r, diff --git a/iroh-net/src/relay/clients.rs b/iroh-net/src/relay/clients.rs index 7f510ababc..32c76a01cf 100644 --- a/iroh-net/src/relay/clients.rs +++ b/iroh-net/src/relay/clients.rs @@ -260,7 +260,11 @@ mod tests { use crate::{ key::SecretKey, - relay::codec::{recv_frame, DerpCodec, Frame, FrameType}, + relay::{ + codec::{recv_frame, DerpCodec, Frame, FrameType}, + server::RelayIo, + MaybeTlsStreamServer as MaybeTlsStream, + }, }; use anyhow::Result; @@ -278,7 +282,7 @@ mod tests { ClientConnBuilder { key, conn_num, - io: Framed::new(crate::relay::server::MaybeTlsStream::Test(io), DerpCodec), + io: RelayIo::Derp(Framed::new(MaybeTlsStream::Test(io), DerpCodec)), write_timeout: None, channel_capacity: 10, server_channel, diff --git a/iroh-net/src/relay/codec.rs b/iroh-net/src/relay/codec.rs index 5eef327e45..a63d88ab20 100644 --- a/iroh-net/src/relay/codec.rs +++ b/iroh-net/src/relay/codec.rs @@ -112,7 +112,7 @@ impl std::fmt::Display for FrameType { } /// Writes complete frame, errors if it is unable to write within the given `timeout`. -/// Ignores the timeout if `timeout.is_zero()` +/// Ignores the timeout if `None` /// /// Does not flush. pub(super) async fn write_frame + Unpin>( @@ -264,8 +264,31 @@ impl Frame { } } + /// Tries to decode a frame received over websockets. + /// + /// Specifically, bytes received from a binary websocket message frame. + pub(crate) fn decode_from_ws_msg(vec: Vec) -> anyhow::Result { + if vec.is_empty() { + bail!("error parsing relay::codec::Frame: too few bytes (0)"); + } + let bytes = Bytes::from(vec); + let typ = FrameType::from(bytes[0]); + let frame = Self::from_bytes(typ, bytes.slice(1..))?; + Ok(frame) + } + + /// Encodes this frame for sending over websockets. + /// + /// Specifically meant for being put into a binary websocket message frame. + pub(crate) fn encode_for_ws_msg(self) -> Vec { + let mut bytes = Vec::new(); + bytes.put_u8(self.typ().into()); + self.write_to(&mut bytes); + bytes + } + /// Writes it self to the given buffer. - fn write_to(&self, dst: &mut BytesMut) { + fn write_to(&self, dst: &mut impl BufMut) { match self { Frame::ClientInfo { client_public_key, @@ -561,6 +584,94 @@ mod tests { assert_eq!(client_info, got_client_info); Ok(()) } + + #[test] + fn test_frame_snapshot() -> anyhow::Result<()> { + let client_key = SecretKey::from_bytes(&[42u8; 32]); + let client_info = ClientInfo { + version: PROTOCOL_VERSION, + }; + let message = postcard::to_stdvec(&client_info)?; + let signature = client_key.sign(&message); + + let frames = vec![ + ( + Frame::ClientInfo { + client_public_key: client_key.public(), + message: Bytes::from(message), + signature, + }, + "02 52 45 4c 41 59 f0 9f 94 91 19 7f 6b 23 e1 6c + 85 32 c6 ab c8 38 fa cd 5e a7 89 be 0c 76 b2 92 + 03 34 03 9b fa 8b 3d 36 8d 61 88 e7 7b 22 f2 92 + ab 37 43 5d a8 de 0b c8 cb 84 e2 88 f4 e7 3b 35 + 82 a5 27 31 e9 ff 98 65 46 5c 87 e0 5e 8d 42 7d + f4 22 bb 6e 85 e1 c0 5f 6f 74 98 37 ba a4 a5 c7 + eb a3 23 0d 77 56 99 10 43 0e 03", + ), + ( + Frame::Health { + problem: "Hello? Yes this is dog.".into(), + }, + "0e 48 65 6c 6c 6f 3f 20 59 65 73 20 74 68 69 73 + 20 69 73 20 64 6f 67 2e", + ), + (Frame::KeepAlive, "06"), + (Frame::NotePreferred { preferred: true }, "07 01"), + ( + Frame::PeerGone { + peer: client_key.public(), + }, + "08 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e + a7 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d + 61", + ), + ( + Frame::Ping { data: [42u8; 8] }, + "0c 2a 2a 2a 2a 2a 2a 2a 2a", + ), + ( + Frame::Pong { data: [42u8; 8] }, + "0d 2a 2a 2a 2a 2a 2a 2a 2a", + ), + ( + Frame::RecvPacket { + src_key: client_key.public(), + content: "Hello World!".into(), + }, + "05 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e + a7 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d + 61 48 65 6c 6c 6f 20 57 6f 72 6c 64 21", + ), + ( + Frame::SendPacket { + dst_key: client_key.public(), + packet: "Goodbye!".into(), + }, + "04 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e + a7 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d + 61 47 6f 6f 64 62 79 65 21", + ), + ( + Frame::Restarting { + reconnect_in: 10, + try_for: 20, + }, + "0f 00 00 00 0a 00 00 00 14", + ), + ]; + + for (frame, expected_hex) in frames { + let bytes = frame.encode_for_ws_msg(); + // To regenerate the hexdumps: + // let hexdump = iroh_test::hexdump::print_hexdump(bytes, []); + // println!("{hexdump}"); + let expected_bytes = iroh_test::hexdump::parse_hexdump(expected_hex)?; + assert_eq!(bytes, expected_bytes); + } + + Ok(()) + } } #[cfg(test)] @@ -664,6 +775,13 @@ mod proptests { prop_assert_eq!(frame, decoded); } + #[test] + fn frame_ws_roundtrip(frame in frame()) { + let encoded = frame.clone().encode_for_ws_msg(); + let decoded = Frame::decode_from_ws_msg(encoded).unwrap(); + prop_assert_eq!(frame, decoded); + } + // Test that typical invalid frames will result in an error #[test] fn broken_frame_handling(frame in frame()) { diff --git a/iroh-net/src/relay/http.rs b/iroh-net/src/relay/http.rs index cd3d7519bf..a61d17b798 100644 --- a/iroh-net/src/relay/http.rs +++ b/iroh-net/src/relay/http.rs @@ -6,9 +6,11 @@ mod server; pub(crate) mod streams; pub use self::client::{Client, ClientBuilder, ClientError, ClientReceiver}; -pub use self::server::{Server, ServerBuilder, ServerHandle, TlsAcceptor, TlsConfig}; +pub use self::server::{Protocol, Server, ServerBuilder, ServerHandle, TlsAcceptor, TlsConfig}; pub(crate) const HTTP_UPGRADE_PROTOCOL: &str = "iroh derp http"; +pub(crate) const WEBSOCKET_UPGRADE_PROTOCOL: &str = "websocket"; +pub(crate) const SUPPORTED_WEBSOCKET_VERSION: &str = "13"; #[cfg(test)] mod tests { diff --git a/iroh-net/src/relay/http/client.rs b/iroh-net/src/relay/http/client.rs index 39c9302bd4..f4d8957832 100644 --- a/iroh-net/src/relay/http/client.rs +++ b/iroh-net/src/relay/http/client.rs @@ -8,6 +8,7 @@ use std::time::Duration; use base64::{engine::general_purpose::URL_SAFE, Engine as _}; use bytes::Bytes; use futures_lite::future::Boxed as BoxFuture; +use futures_util::StreamExt; use http_body_util::Empty; use hyper::body::Incoming; use hyper::header::UPGRADE; @@ -21,11 +22,14 @@ use tokio::net::TcpStream; use tokio::sync::{mpsc, oneshot}; use tokio::task::JoinSet; use tokio::time::Instant; +use tokio_util::codec::{FramedRead, FramedWrite}; use tracing::{debug, error, info_span, trace, warn, Instrument}; use url::Url; use crate::dns::{DnsResolver, ResolverExt}; use crate::key::{PublicKey, SecretKey}; +use crate::relay::client::{ConnReader, ConnWriter}; +use crate::relay::codec::DerpCodec; use crate::relay::http::streams::{downcast_upgrade, MaybeTlsStream}; use crate::relay::RelayUrl; use crate::relay::{ @@ -35,6 +39,7 @@ use crate::relay::{ use crate::util::chain; use crate::util::AbortingJoinHandle; +use super::server::Protocol; use super::streams::ProxyStream; const DIAL_NODE_TIMEOUT: Duration = Duration::from_millis(1500); @@ -120,6 +125,9 @@ pub enum ClientError { /// The inner actor is gone, likely means things are shutdown. #[error("actor gone")] ActorGone, + /// An error related to websockets, either errors with parsing ws messages or the handshake + #[error("websocket error: {0}")] + WebsocketError(#[from] tokio_tungstenite_wasm::Error), } /// An HTTP Relay client. @@ -580,6 +588,51 @@ impl Actor { } async fn connect_0(&self) -> Result<(RelayClient, RelayClientReceiver), ClientError> { + // We determine which protocol to use for relays via the URL scheme: ws(s) vs. http(s) + let protocol = Protocol::from_url_scheme(&self.url); + + let (reader, writer, local_addr) = match &protocol { + Protocol::Websocket => { + let (reader, writer) = self.connect_ws().await?; + let local_addr = None; + (reader, writer, local_addr) + } + Protocol::Relay => { + let (reader, writer, local_addr) = self.connect_derp().await?; + (reader, writer, Some(local_addr)) + } + }; + + let (relay_client, receiver) = + RelayClientBuilder::new(self.secret_key.clone(), local_addr, reader, writer) + .build() + .await + .map_err(|e| ClientError::Build(e.to_string()))?; + + if self.is_preferred && relay_client.note_preferred(true).await.is_err() { + relay_client.close().await; + return Err(ClientError::Send); + } + + trace!("connect_0 done"); + Ok((relay_client, receiver)) + } + + async fn connect_ws(&self) -> Result<(ConnReader, ConnWriter), ClientError> { + let mut dial_url = (*self.url).clone(); + dial_url.set_path("/derp"); + + debug!(%dial_url, "Dialing relay by websocket"); + + let (writer, reader) = tokio_tungstenite_wasm::connect(dial_url).await?.split(); + + let reader = ConnReader::Ws(reader); + let writer = ConnWriter::Ws(writer); + + Ok((reader, writer)) + } + + async fn connect_derp(&self) -> Result<(ConnReader, ConnWriter, SocketAddr), ClientError> { let tcp_stream = self.dial_url().await?; let local_addr = tcp_stream @@ -588,7 +641,7 @@ impl Actor { debug!(server_addr = ?tcp_stream.peer_addr(), %local_addr, "TCP stream connected"); - let response = if self.use_https() { + let response = if self.use_tls() { debug!("Starting TLS handshake"); let hostname = self .tls_servername() @@ -625,19 +678,10 @@ impl Actor { let (reader, writer) = downcast_upgrade(upgraded).map_err(|e| ClientError::Upgrade(e.to_string()))?; - let (relay_client, receiver) = - RelayClientBuilder::new(self.secret_key.clone(), local_addr, reader, writer) - .build() - .await - .map_err(|e| ClientError::Build(e.to_string()))?; - - if self.is_preferred && relay_client.note_preferred(true).await.is_err() { - relay_client.close().await; - return Err(ClientError::Send); - } + let reader = ConnReader::Derp(FramedRead::new(reader, DerpCodec)); + let writer = ConnWriter::Derp(FramedWrite::new(writer, DerpCodec)); - trace!("connect_0 done"); - Ok((relay_client, receiver)) + Ok((reader, writer, local_addr)) } /// Sends the HTTP upgrade request to the relay server. @@ -663,7 +707,7 @@ impl Actor { debug!("Sending upgrade request"); let req = Request::builder() .uri("/derp") - .header(UPGRADE, super::HTTP_UPGRADE_PROTOCOL) + .header(UPGRADE, Protocol::Relay.upgrade_header()) .body(http_body_util::Empty::::new())?; request_sender.send_request(req).await.map_err(From::from) } @@ -695,12 +739,10 @@ impl Actor { return None; } if let Some((ref client, _)) = self.relay_client { - match client.local_addr() { - Ok(addr) => return Some(addr), - _ => return None, - } + client.local_addr() + } else { + None } - None } async fn ping(&mut self, s: oneshot::Sender>) { @@ -782,12 +824,14 @@ impl Actor { .and_then(|s| rustls::ServerName::try_from(s).ok()) } - fn use_https(&self) -> bool { - // only disable https if we are explicitly dialing a http url - if self.url.scheme() == "http" { - return false; + fn use_tls(&self) -> bool { + // only disable tls if we are explicitly dialing a http url + #[allow(clippy::match_like_matches_macro)] + match self.url.scheme() { + "http" => false, + "ws" => false, + _ => true, } - true } async fn dial_url(&self) -> Result { diff --git a/iroh-net/src/relay/http/server.rs b/iroh-net/src/relay/http/server.rs index eaf6ffd70a..f21cb75460 100644 --- a/iroh-net/src/relay/http/server.rs +++ b/iroh-net/src/relay/http/server.rs @@ -8,20 +8,25 @@ use anyhow::{bail, ensure, Context as _, Result}; use bytes::Bytes; use derive_more::Debug; use futures_lite::FutureExt; +use http::header::CONNECTION; use http::response::Builder as ResponseBuilder; use hyper::body::Incoming; use hyper::header::{HeaderValue, UPGRADE}; use hyper::service::Service; use hyper::upgrade::Upgraded; use hyper::{HeaderMap, Method, Request, Response, StatusCode}; +use iroh_base::node_addr::RelayUrl; use tokio::net::{TcpListener, TcpStream}; use tokio::task::JoinHandle; use tokio_rustls_acme::AcmeAcceptor; use tokio_util::sync::CancellationToken; -use tracing::{debug, error, info, info_span, Instrument}; +use tracing::{debug, debug_span, error, info, info_span, warn, Instrument}; +use tungstenite::handshake::derive_accept_key; use crate::key::SecretKey; -use crate::relay::http::HTTP_UPGRADE_PROTOCOL; +use crate::relay::http::{ + HTTP_UPGRADE_PROTOCOL, SUPPORTED_WEBSOCKET_VERSION, WEBSOCKET_UPGRADE_PROTOCOL, +}; use crate::relay::server::{ClientConnHandler, MaybeTlsStream}; use crate::relay::MaybeTlsStreamServer; @@ -54,20 +59,68 @@ fn downcast_upgrade(upgraded: Upgraded) -> Result<(MaybeTlsStream, Bytes)> { } } -/// The server HTTP handler to do HTTP upgrades -async fn relay_connection_handler( - conn_handler: &ClientConnHandler, - upgraded: Upgraded, -) -> Result<()> { - debug!("relay_connection upgraded"); - let (io, read_buf) = downcast_upgrade(upgraded)?; - ensure!( - read_buf.is_empty(), - "can not deal with buffered data yet: {:?}", - read_buf - ); - - conn_handler.accept(io).await +/// The HTTP upgrade protocol used for relaying. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum Protocol { + /// Relays over the custom relaying protocol with a custom HTTP upgrade header. + Relay, + /// Relays over websockets. + /// + /// Originally introduced to support browser connections. + Websocket, +} + +impl Protocol { + /// The HTTP upgrade header used or expected + pub const fn upgrade_header(&self) -> &'static str { + match self { + Protocol::Relay => HTTP_UPGRADE_PROTOCOL, + Protocol::Websocket => WEBSOCKET_UPGRADE_PROTOCOL, + } + } + + /// Determines which protocol to use depending on a URL. + /// + /// `ws(s)` parses as websockets, `http(s)` parses to the custom relay protocol. + pub fn from_url_scheme(url: &RelayUrl) -> Self { + match url.scheme() { + "ws" => Protocol::Websocket, + "wss" => Protocol::Websocket, + "http" => Protocol::Relay, + "https" => Protocol::Relay, + // We default to relay in case of weird URLs. + _ => Protocol::Relay, + } + } + + /// Tries to match the value of an HTTP upgrade header to figure out which protocol should be initiated. + pub fn parse_header(header: &HeaderValue) -> Option { + let header_bytes = header.as_bytes(); + if header_bytes == Protocol::Relay.upgrade_header().as_bytes() { + Some(Protocol::Relay) + } else if header_bytes == Protocol::Websocket.upgrade_header().as_bytes() { + Some(Protocol::Websocket) + } else { + None + } + } + + /// The server HTTP handler to do HTTP upgrades + async fn relay_connection_handler( + self, + conn_handler: &ClientConnHandler, + upgraded: Upgraded, + ) -> Result<()> { + debug!(protocol = ?self, "relay_connection upgraded"); + let (io, read_buf) = downcast_upgrade(upgraded)?; + ensure!( + read_buf.is_empty(), + "can not deal with buffered data yet: {:?}", + read_buf + ); + + conn_handler.accept(self, io).await + } } /// The Relay HTTP server. @@ -338,7 +391,7 @@ impl ServerState { // we will use this cancel token to stop the infinite loop in the `listener.accept() task` let cancel_server_loop = CancellationToken::new(); let addr = listener.local_addr()?; - let http_str = tls_config.as_ref().map_or("HTTP", |_| "HTTPS"); + let http_str = tls_config.as_ref().map_or("HTTP/WS", |_| "HTTPS/WSS"); info!("[{http_str}] relay: serving on {addr}"); let cancel = cancel_server_loop.clone(); let task = tokio::task::spawn(async move { @@ -410,13 +463,48 @@ impl Service> for ClientConnHandler { async move { { - let mut res = builder.body(body_empty()).expect("valid body"); - // Send a 400 to any request that doesn't have an `Upgrade` header. - if !req.headers().contains_key(UPGRADE) { - *res.status_mut() = StatusCode::BAD_REQUEST; - return Ok(res); - } + let Some(protocol) = req.headers().get(UPGRADE).and_then(Protocol::parse_header) + else { + return Ok(builder + .status(StatusCode::BAD_REQUEST) + .body(body_empty()) + .expect("valid body")); + }; + + let websocket_headers = if protocol == Protocol::Websocket { + let Some(key) = req.headers().get("Sec-WebSocket-Key").cloned() else { + warn!("missing header Sec-WebSocket-Key for websocket relay protocol"); + return Ok(builder + .status(StatusCode::BAD_REQUEST) + .body(body_empty()) + .expect("valid body")); + }; + + let Some(version) = req.headers().get("Sec-WebSocket-Version").cloned() else { + warn!("missing header Sec-WebSocket-Version for websocket relay protocol"); + return Ok(builder + .status(StatusCode::BAD_REQUEST) + .body(body_empty()) + .expect("valid body")); + }; + + if version.as_bytes() != SUPPORTED_WEBSOCKET_VERSION.as_bytes() { + warn!("invalid header Sec-WebSocket-Version: {:?}", version); + return Ok(builder + .status(StatusCode::BAD_REQUEST) + // It's convention to send back the version(s) we *do* support + .header("Sec-WebSocket-Version", SUPPORTED_WEBSOCKET_VERSION) + .body(body_empty()) + .expect("valid body")); + } + + Some((key, version)) + } else { + None + }; + + debug!("upgrading protocol: {:?}", protocol); // Setup a future that will eventually receive the upgraded // connection and talk a new protocol, and spawn the future @@ -429,31 +517,40 @@ impl Service> for ClientConnHandler { async move { match hyper::upgrade::on(&mut req).await { Ok(upgraded) => { - if let Err(e) = - relay_connection_handler(&closure_conn_handler, upgraded).await + if let Err(e) = protocol + .relay_connection_handler(&closure_conn_handler, upgraded) + .await { - tracing::warn!( - "upgrade to \"{HTTP_UPGRADE_PROTOCOL}\": io error: {:?}", - e + warn!( + "upgrade to \"{}\": io error: {:?}", + e, + protocol.upgrade_header() ); } else { - tracing::debug!( - "upgrade to \"{HTTP_UPGRADE_PROTOCOL}\" success" - ); + debug!("upgrade to \"{}\" success", protocol.upgrade_header()); }; } - Err(e) => tracing::warn!("upgrade error: {:?}", e), + Err(e) => warn!("upgrade error: {:?}", e), } } - .instrument(tracing::debug_span!("handler")), + .instrument(debug_span!("handler")), ); // Now return a 101 Response saying we agree to the upgrade to the // HTTP_UPGRADE_PROTOCOL - *res.status_mut() = StatusCode::SWITCHING_PROTOCOLS; - res.headers_mut() - .insert(UPGRADE, HeaderValue::from_static(HTTP_UPGRADE_PROTOCOL)); - Ok(res) + builder = builder + .status(StatusCode::SWITCHING_PROTOCOLS) + .header(UPGRADE, HeaderValue::from_static(protocol.upgrade_header())); + + if let Some((key, _version)) = websocket_headers { + Ok(builder + .header("Sec-WebSocket-Accept", &derive_accept_key(key.as_bytes())) + .header(CONNECTION, "upgrade") + .body(body_full("switching to websocket protocol")) + .expect("valid body")) + } else { + Ok(builder.body(body_empty()).expect("valid body")) + } } } .boxed() diff --git a/iroh-net/src/relay/iroh_relay.rs b/iroh-net/src/relay/iroh_relay.rs index 928cdbaa8c..253b7077f0 100644 --- a/iroh-net/src/relay/iroh_relay.rs +++ b/iroh-net/src/relay/iroh_relay.rs @@ -801,7 +801,7 @@ mod tests { } #[tokio::test] - async fn test_relay_clients() { + async fn test_relay_clients_both_derp() { let _guard = iroh_test::logging::setup(); let server = Server::spawn(ServerConfig::<(), ()> { relay: Some(RelayConfig { @@ -876,6 +876,164 @@ mod tests { } } + #[tokio::test] + async fn test_relay_clients_both_websockets() { + let _guard = iroh_test::logging::setup(); + let server = Server::spawn(ServerConfig::<(), ()> { + relay: Some(RelayConfig { + secret_key: SecretKey::generate(), + http_bind_addr: (Ipv4Addr::LOCALHOST, 0).into(), + tls: None, + limits: Default::default(), + }), + stun: None, + metrics_addr: None, + }) + .await + .unwrap(); + // NOTE: Using `ws://` URL scheme to trigger websockets. + let relay_url = format!("ws://{}", server.http_addr().unwrap()); + let relay_url: RelayUrl = relay_url.parse().unwrap(); + + // set up client a + let a_secret_key = SecretKey::generate(); + let a_key = a_secret_key.public(); + let resolver = crate::dns::default_resolver().clone(); + let (client_a, mut client_a_receiver) = + ClientBuilder::new(relay_url.clone()).build(a_secret_key, resolver); + let connect_client = client_a.clone(); + + // give the relay server some time to accept connections + if let Err(err) = tokio::time::timeout(Duration::from_secs(10), async move { + loop { + match connect_client.connect().await { + Ok(_) => break, + Err(err) => { + warn!("client unable to connect to relay server: {err:#}"); + tokio::time::sleep(Duration::from_millis(100)).await; + } + } + } + }) + .await + { + panic!("error connecting to relay server: {err:#}"); + } + + // set up client b + let b_secret_key = SecretKey::generate(); + let b_key = b_secret_key.public(); + let resolver = crate::dns::default_resolver().clone(); + let (client_b, mut client_b_receiver) = + ClientBuilder::new(relay_url.clone()).build(b_secret_key, resolver); + client_b.connect().await.unwrap(); + + // send message from a to b + let msg = Bytes::from("hello, b"); + client_a.send(b_key, msg.clone()).await.unwrap(); + + let (res, _) = client_b_receiver.recv().await.unwrap().unwrap(); + if let ReceivedMessage::ReceivedPacket { source, data } = res { + assert_eq!(a_key, source); + assert_eq!(msg, data); + } else { + panic!("client_b received unexpected message {res:?}"); + } + + // send message from b to a + let msg = Bytes::from("howdy, a"); + client_b.send(a_key, msg.clone()).await.unwrap(); + + let (res, _) = client_a_receiver.recv().await.unwrap().unwrap(); + if let ReceivedMessage::ReceivedPacket { source, data } = res { + assert_eq!(b_key, source); + assert_eq!(msg, data); + } else { + panic!("client_a received unexpected message {res:?}"); + } + } + + #[tokio::test] + async fn test_relay_clients_websocket_and_derp() { + let _guard = iroh_test::logging::setup(); + let server = Server::spawn(ServerConfig::<(), ()> { + relay: Some(RelayConfig { + secret_key: SecretKey::generate(), + http_bind_addr: (Ipv4Addr::LOCALHOST, 0).into(), + tls: None, + limits: Default::default(), + }), + stun: None, + metrics_addr: None, + }) + .await + .unwrap(); + + let derp_relay_url = format!("http://{}", server.http_addr().unwrap()); + let derp_relay_url: RelayUrl = derp_relay_url.parse().unwrap(); + + // NOTE: Using `ws://` URL scheme to trigger websockets. + let ws_relay_url = format!("ws://{}", server.http_addr().unwrap()); + let ws_relay_url: RelayUrl = ws_relay_url.parse().unwrap(); + + // set up client a + let a_secret_key = SecretKey::generate(); + let a_key = a_secret_key.public(); + let resolver = crate::dns::default_resolver().clone(); + let (client_a, mut client_a_receiver) = + ClientBuilder::new(derp_relay_url.clone()).build(a_secret_key, resolver); + let connect_client = client_a.clone(); + + // give the relay server some time to accept connections + if let Err(err) = tokio::time::timeout(Duration::from_secs(10), async move { + loop { + match connect_client.connect().await { + Ok(_) => break, + Err(err) => { + warn!("client unable to connect to relay server: {err:#}"); + tokio::time::sleep(Duration::from_millis(100)).await; + } + } + } + }) + .await + { + panic!("error connecting to relay server: {err:#}"); + } + + // set up client b + let b_secret_key = SecretKey::generate(); + let b_key = b_secret_key.public(); + let resolver = crate::dns::default_resolver().clone(); + let (client_b, mut client_b_receiver) = + ClientBuilder::new(ws_relay_url.clone()).build(b_secret_key, resolver); + client_b.connect().await.unwrap(); + + // send message from a to b + let msg = Bytes::from("hello, b"); + client_a.send(b_key, msg.clone()).await.unwrap(); + + let (res, _) = client_b_receiver.recv().await.unwrap().unwrap(); + if let ReceivedMessage::ReceivedPacket { source, data } = res { + assert_eq!(a_key, source); + assert_eq!(msg, data); + } else { + panic!("client_b received unexpected message {res:?}"); + } + + // send message from b to a + let msg = Bytes::from("howdy, a"); + client_b.send(a_key, msg.clone()).await.unwrap(); + + let (res, _) = client_a_receiver.recv().await.unwrap().unwrap(); + if let ReceivedMessage::ReceivedPacket { source, data } = res { + assert_eq!(b_key, source); + assert_eq!(msg, data); + } else { + panic!("client_a received unexpected message {res:?}"); + } + } + #[tokio::test] async fn test_stun() { let _guard = iroh_test::logging::setup(); diff --git a/iroh-net/src/relay/metrics.rs b/iroh-net/src/relay/metrics.rs index 3100d857f7..923d375b4a 100644 --- a/iroh-net/src/relay/metrics.rs +++ b/iroh-net/src/relay/metrics.rs @@ -60,6 +60,11 @@ pub struct Metrics { pub accepts: Counter, /// Number of connections we have removed because of an error pub disconnects: Counter, + + /// Number of accepted websocket connections + pub websocket_accepts: Counter, + /// Number of accepted 'iroh derp http' connection upgrades + pub derp_accepts: Counter, // TODO: enable when we can have multiple connections for one node id // pub duplicate_client_keys: Counter, // pub duplicate_client_conns: Counter, @@ -115,6 +120,9 @@ impl Default for Metrics { accepts: Counter::new("Number of times this server has accepted a connection."), disconnects: Counter::new("Number of clients that have then disconnected."), + + websocket_accepts: Counter::new("Number of accepted websocket connections"), + derp_accepts: Counter::new("Number of accepted 'iroh derp http' connection upgrades"), // TODO: enable when we can have multiple connections for one node id // pub duplicate_client_keys: Counter::new("Number of duplicate client keys."), // pub duplicate_client_conns: Counter::new("Number of duplicate client connections."), diff --git a/iroh-net/src/relay/server.rs b/iroh-net/src/relay/server.rs index 05dbc60ad7..8240b0fb2c 100644 --- a/iroh-net/src/relay/server.rs +++ b/iroh-net/src/relay/server.rs @@ -6,18 +6,24 @@ use std::task::{Context, Poll}; use std::time::Duration; use anyhow::{bail, Context as _, Result}; +use futures_lite::Stream; +use futures_sink::Sink; use hyper::HeaderMap; use iroh_metrics::core::UsageStatsReport; use iroh_metrics::{inc, report_usage_stats}; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::sync::mpsc; use tokio::task::JoinHandle; +use tokio_tungstenite::WebSocketStream; use tokio_util::codec::Framed; use tokio_util::sync::CancellationToken; use tracing::{info_span, trace, Instrument}; +use tungstenite::protocol::Role; use crate::key::{PublicKey, SecretKey}; +use super::codec::Frame; +use super::http::Protocol; use super::{ client_conn::ClientConnBuilder, clients::Clients, @@ -175,9 +181,18 @@ impl ClientConnHandler { /// and is unable to verify this one, or if there is some issue communicating with the server. /// /// The provided [`AsyncRead`] and [`AsyncWrite`] must be already connected to the connection. - pub async fn accept(&self, io: MaybeTlsStream) -> Result<()> { - let mut io = Framed::new(io, DerpCodec); - trace!("accept: start"); + pub async fn accept(&self, protocol: Protocol, io: MaybeTlsStream) -> Result<()> { + trace!(?protocol, "accept: start"); + let mut io = match protocol { + Protocol::Relay => { + inc!(Metrics, derp_accepts); + RelayIo::Derp(Framed::new(io, DerpCodec)) + } + Protocol::Websocket => { + inc!(Metrics, websocket_accepts); + RelayIo::Ws(WebSocketStream::from_raw_socket(io, Role::Server, None).await) + } + }; trace!("accept: recv client key"); let (client_key, info) = recv_client_key(&mut io) .await @@ -247,9 +262,9 @@ impl ServerActor { anyhow::bail!("server channel sender closed unexpectedly, closed client connections, and shutting down server loop"); } }; - match msg { + match msg { ServerMessage::SendPacket((key, packet)) => { - tracing::trace!("send packet from: {:?} to: {:?} ({}b)", packet.src, key, packet.bytes.len()); + tracing::trace!("send packet from: {:?} to: {:?} ({}b)", packet.src, key, packet.bytes.len()); let src = packet.src; if self.clients.contains_key(&key) { // if this client is in our local network, just try to send the @@ -262,14 +277,13 @@ impl ServerActor { inc!(Metrics, send_packets_dropped); } } - ServerMessage::SendDiscoPacket((key, packet)) => { - tracing::trace!("send disco packet from: {:?} to: {:?} ({}b)", packet.src, key, packet.bytes.len()); + ServerMessage::SendDiscoPacket((key, packet)) => { + tracing::trace!("send disco packet from: {:?} to: {:?} ({}b)", packet.src, key, packet.bytes.len()); let src = packet.src; if self.clients.contains_key(&key) { // if this client is in our local network, just try to send the // packet if self.clients.send_disco_packet(&key, packet).is_ok() { - self.clients.record_send(&src, key); } } else { @@ -349,6 +363,75 @@ fn init_meta_cert(server_key: &PublicKey) -> Vec { .expect("fixed allocations") } +#[derive(Debug)] +pub(crate) enum RelayIo { + Derp(Framed), + Ws(WebSocketStream), +} + +fn tung_to_io_err(e: tungstenite::Error) -> std::io::Error { + match e { + tungstenite::Error::Io(io_err) => io_err, + _ => std::io::Error::new(std::io::ErrorKind::Other, e.to_string()), + } +} + +impl Sink for RelayIo { + type Error = std::io::Error; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match *self { + Self::Derp(ref mut framed) => Pin::new(framed).poll_ready(cx), + Self::Ws(ref mut ws) => Pin::new(ws).poll_ready(cx).map_err(tung_to_io_err), + } + } + + fn start_send(mut self: Pin<&mut Self>, item: Frame) -> Result<(), Self::Error> { + match *self { + Self::Derp(ref mut framed) => Pin::new(framed).start_send(item), + Self::Ws(ref mut ws) => Pin::new(ws) + .start_send(tungstenite::Message::Binary(item.encode_for_ws_msg())) + .map_err(tung_to_io_err), + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match *self { + Self::Derp(ref mut framed) => Pin::new(framed).poll_flush(cx), + Self::Ws(ref mut ws) => Pin::new(ws).poll_flush(cx).map_err(tung_to_io_err), + } + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match *self { + Self::Derp(ref mut framed) => Pin::new(framed).poll_close(cx), + Self::Ws(ref mut ws) => Pin::new(ws).poll_close(cx).map_err(tung_to_io_err), + } + } +} + +impl Stream for RelayIo { + type Item = anyhow::Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match *self { + Self::Derp(ref mut framed) => Pin::new(framed).poll_next(cx), + Self::Ws(ref mut ws) => match Pin::new(ws).poll_next(cx) { + Poll::Ready(Some(Ok(tungstenite::Message::Binary(vec)))) => { + Poll::Ready(Some(Frame::decode_from_ws_msg(vec))) + } + Poll::Ready(Some(Ok(msg))) => { + tracing::warn!(?msg, "Got websocket message of unsupported type, skipping."); + Poll::Pending + } + Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e.into()))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + }, + } + } +} + /// Whether or not the underlying [`tokio::net::TcpStream`] is served over Tls #[derive(Debug)] pub enum MaybeTlsStream { @@ -432,8 +515,8 @@ mod tests { use super::*; use crate::relay::{ - client::ClientBuilder, - codec::{recv_frame, Frame, FrameType}, + client::{ClientBuilder, ConnReader, ConnWriter}, + codec::{recv_frame, FrameType}, http::streams::{MaybeTlsStreamReader, MaybeTlsStreamWriter}, types::ClientInfo, ReceivedMessage, @@ -454,7 +537,7 @@ mod tests { ClientConnBuilder { key, conn_num, - io: Framed::new(MaybeTlsStream::Test(io), DerpCodec), + io: RelayIo::Derp(Framed::new(MaybeTlsStream::Test(io), DerpCodec)), write_timeout: None, channel_capacity: 10, server_channel, @@ -562,7 +645,9 @@ mod tests { }); // attempt to add the connection to the server - handler.accept(MaybeTlsStream::Test(server_io)).await?; + handler + .accept(Protocol::Relay, MaybeTlsStream::Test(server_io)) + .await?; client_task.await??; // ensure we inform the server to create the client from the connection! @@ -578,16 +663,16 @@ mod tests { fn make_test_client(secret_key: SecretKey) -> (tokio::io::DuplexStream, ClientBuilder) { let (client, server) = tokio::io::duplex(10); let (client_reader, client_writer) = tokio::io::split(client); + let client_reader = MaybeTlsStreamReader::Mem(client_reader); let client_writer = MaybeTlsStreamWriter::Mem(client_writer); + + let client_reader = ConnReader::Derp(FramedRead::new(client_reader, DerpCodec)); + let client_writer = ConnWriter::Derp(FramedWrite::new(client_writer, DerpCodec)); + ( server, - ClientBuilder::new( - secret_key, - "127.0.0.1:0".parse().unwrap(), - client_reader, - client_writer, - ), + ClientBuilder::new(secret_key, None, client_reader, client_writer), ) } @@ -604,8 +689,11 @@ mod tests { let public_key_a = key_a.public(); let (rw_a, client_a_builder) = make_test_client(key_a); let handler = server.client_conn_handler(Default::default()); - let handler_task = - tokio::spawn(async move { handler.accept(MaybeTlsStream::Test(rw_a)).await }); + let handler_task = tokio::spawn(async move { + handler + .accept(Protocol::Relay, MaybeTlsStream::Test(rw_a)) + .await + }); let (client_a, mut client_receiver_a) = client_a_builder.build().await?; handler_task.await??; @@ -614,8 +702,11 @@ mod tests { let public_key_b = key_b.public(); let (rw_b, client_b_builder) = make_test_client(key_b); let handler = server.client_conn_handler(Default::default()); - let handler_task = - tokio::spawn(async move { handler.accept(MaybeTlsStream::Test(rw_b)).await }); + let handler_task = tokio::spawn(async move { + handler + .accept(Protocol::Relay, MaybeTlsStream::Test(rw_b)) + .await + }); let (client_b, mut client_receiver_b) = client_b_builder.build().await?; handler_task.await??; @@ -674,8 +765,11 @@ mod tests { let public_key_a = key_a.public(); let (rw_a, client_a_builder) = make_test_client(key_a); let handler = server.client_conn_handler(Default::default()); - let handler_task = - tokio::spawn(async move { handler.accept(MaybeTlsStream::Test(rw_a)).await }); + let handler_task = tokio::spawn(async move { + handler + .accept(Protocol::Relay, MaybeTlsStream::Test(rw_a)) + .await + }); let (client_a, mut client_receiver_a) = client_a_builder.build().await?; handler_task.await??; @@ -684,8 +778,11 @@ mod tests { let public_key_b = key_b.public(); let (rw_b, client_b_builder) = make_test_client(key_b.clone()); let handler = server.client_conn_handler(Default::default()); - let handler_task = - tokio::spawn(async move { handler.accept(MaybeTlsStream::Test(rw_b)).await }); + let handler_task = tokio::spawn(async move { + handler + .accept(Protocol::Relay, MaybeTlsStream::Test(rw_b)) + .await + }); let (client_b, mut client_receiver_b) = client_b_builder.build().await?; handler_task.await??; @@ -718,8 +815,11 @@ mod tests { // create client b and connect it to the server let (new_rw_b, new_client_b_builder) = make_test_client(key_b); let handler = server.client_conn_handler(Default::default()); - let handler_task = - tokio::spawn(async move { handler.accept(MaybeTlsStream::Test(new_rw_b)).await }); + let handler_task = tokio::spawn(async move { + handler + .accept(Protocol::Relay, MaybeTlsStream::Test(new_rw_b)) + .await + }); let (new_client_b, mut new_client_receiver_b) = new_client_b_builder.build().await?; handler_task.await??; diff --git a/iroh-test/src/logging.rs b/iroh-test/src/logging.rs index 45636e43ae..fba1cfbd7f 100644 --- a/iroh-test/src/logging.rs +++ b/iroh-test/src/logging.rs @@ -25,6 +25,7 @@ use tracing_subscriber::EnvFilter; /// let _guard = iroh_test::logging::setup(); /// assert!(true); /// } +/// ``` #[must_use = "The tracing guard must only be dropped at the end of the test"] pub fn setup() -> tracing::subscriber::DefaultGuard { if let Ok(handle) = tokio::runtime::Handle::try_current() { diff --git a/iroh/src/node.rs b/iroh/src/node.rs index 238577d40c..679f81f4ee 100644 --- a/iroh/src/node.rs +++ b/iroh/src/node.rs @@ -461,8 +461,6 @@ async fn handle_connection( #[cfg(test)] mod tests { - use std::time::Duration; - use anyhow::{bail, Context}; use bytes::Bytes; use iroh_base::node_addr::AddrInfoOptions;