diff --git a/iroh-net/src/relay/client.rs b/iroh-net/src/relay/client.rs index 13290d10449..40ae791dd33 100644 --- a/iroh-net/src/relay/client.rs +++ b/iroh-net/src/relay/client.rs @@ -226,19 +226,19 @@ impl + Unpin + 'static> ClientWriter { match msg { ClientWriterMessage::Packet((key, bytes)) => { send_packet(&mut self.writer, &self.rate_limiter, key, bytes).await?; - self.writer.flush(); + self.writer.flush().await?; } ClientWriterMessage::Pong(data) => { write_frame(&mut self.writer, Frame::Pong { data }, None).await?; - self.writer.flush(); + self.writer.flush().await?; } ClientWriterMessage::Ping(data) => { write_frame(&mut self.writer, Frame::Ping { data }, None).await?; - self.writer.flush(); + self.writer.flush().await?; } ClientWriterMessage::NotePreferred(preferred) => { write_frame(&mut self.writer, Frame::NotePreferred { preferred }, None).await?; - self.writer.flush(); + self.writer.flush().await?; } ClientWriterMessage::Shutdown => { return Ok(()); @@ -260,11 +260,19 @@ pub struct ClientBuilder { #[derive(derive_more::Debug)] pub(crate) enum RelayConnReader { + Relay( + #[debug("FramedRead")] + FramedRead, + ), Ws(#[debug("SplitStream")] SplitStream), } #[derive(derive_more::Debug)] pub(crate) enum RelayConnWriter { + Relay( + #[debug("FramedWrite")] + FramedWrite, + ), Ws( #[debug("SplitSink")] SplitSink, @@ -280,6 +288,7 @@ impl Stream for RelayConnReader { fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match *self { + Self::Relay(ref mut ws) => ws.poll_next_unpin(cx), Self::Ws(ref mut ws) => ws.poll_next_unpin(cx).map(Frame::from_wasm_ws_message), } } @@ -290,12 +299,14 @@ impl Sink for RelayConnWriter { fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match *self { + Self::Relay(ref mut ws) => ws.poll_ready_unpin(cx), Self::Ws(ref mut ws) => ws.poll_ready_unpin(cx).map_err(tung_to_io_err), } } fn start_send(mut self: Pin<&mut Self>, item: Frame) -> Result<(), Self::Error> { match *self { + Self::Relay(ref mut ws) => ws.start_send_unpin(item), Self::Ws(ref mut ws) => ws .start_send_unpin(item.into_wasm_ws_message()?) .map_err(tung_to_io_err), @@ -304,12 +315,14 @@ impl Sink for RelayConnWriter { fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match *self { + Self::Relay(ref mut ws) => ws.poll_flush_unpin(cx), Self::Ws(ref mut ws) => ws.poll_flush_unpin(cx).map_err(tung_to_io_err), } } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match *self { + Self::Relay(ref mut ws) => ws.poll_close_unpin(cx), Self::Ws(ref mut ws) => ws.poll_close_unpin(cx).map_err(tung_to_io_err), } } diff --git a/iroh-net/src/relay/http/client.rs b/iroh-net/src/relay/http/client.rs index 4baa0c7fbe7..8cbbd982acd 100644 --- a/iroh-net/src/relay/http/client.rs +++ b/iroh-net/src/relay/http/client.rs @@ -24,12 +24,14 @@ use tokio::sync::{mpsc, oneshot}; use tokio::task::JoinSet; use tokio::time::Instant; use tokio_tungstenite_wasm::WebSocketStream; +use tokio_util::codec::{FramedRead, FramedWrite}; use tracing::{debug, error, info_span, trace, warn, Instrument}; use url::Url; use crate::dns::{DnsResolver, ResolverExt}; use crate::key::{PublicKey, SecretKey}; use crate::relay::client::{RelayConnReader, RelayConnWriter}; +use crate::relay::codec::DerpCodec; use crate::relay::http::streams::{downcast_upgrade, MaybeTlsStream}; use crate::relay::http::WEBSOCKET_UPGRADE_PROTOCOL; use crate::relay::RelayUrl; @@ -40,6 +42,7 @@ use crate::relay::{ use crate::util::chain; use crate::util::AbortingJoinHandle; +use super::server::Protocol; use super::streams::ProxyStream; const DIAL_NODE_TIMEOUT: Duration = Duration::from_millis(1500); @@ -588,62 +591,73 @@ impl Actor { } async fn connect_0(&self) -> Result<(RelayClient, RelayClientReceiver), ClientError> { - // let tcp_stream = self.dial_url().await?; - - // let local_addr = tcp_stream - // .local_addr() - // .map_err(|e| ClientError::NoLocalAddr(e.to_string()))?; - - // debug!(server_addr = ?tcp_stream.peer_addr(), %local_addr, "TCP stream connected"); - - // let response = if self.use_https() { - // debug!("Starting TLS handshake"); - // let hostname = self - // .tls_servername() - // .ok_or_else(|| ClientError::InvalidUrl("No tls servername".into()))?; - // let tls_stream = self.tls_connector.connect(hostname, tcp_stream).await?; - // debug!("tls_connector connect success"); - // Self::start_upgrade(tls_stream).await? - // } else { - // debug!("Starting handshake"); - // Self::start_upgrade(tcp_stream).await? - // }; - - // if response.status() != hyper::StatusCode::SWITCHING_PROTOCOLS { - // error!( - // "expected status 101 SWITCHING_PROTOCOLS, got: {}", - // response.status() - // ); - // return Err(ClientError::UnexpectedStatusCode( - // hyper::StatusCode::SWITCHING_PROTOCOLS, - // response.status(), - // )); - // } - - // debug!("starting upgrade"); - // let upgraded = match hyper::upgrade::on(response).await { - // Ok(upgraded) => upgraded, - // Err(err) => { - // warn!("upgrade failed: {:#}", err); - // return Err(ClientError::Hyper(err)); - // } - // }; - - // debug!("connection upgraded"); - // let (reader, writer) = - // downcast_upgrade(upgraded).map_err(|e| ClientError::Upgrade(e.to_string()))?; - - // let (reader, writer) = - // fastwebsockets::after_handshake_split(reader, writer, fastwebsockets::Role::Client); - - let (writer, reader) = tokio_tungstenite_wasm::connect(self.url.as_str()) - .await? - .split(); - - let reader = RelayConnReader::Ws(reader); - let writer = RelayConnWriter::Ws(writer); - - let local_addr = None; + const PROTOCOL: Protocol = Protocol::Relay; + + let (reader, writer, local_addr) = match &PROTOCOL { + Protocol::Websocket => { + let (writer, reader) = tokio_tungstenite_wasm::connect(self.url.as_str()) + .await? + .split(); + + let reader = RelayConnReader::Ws(reader); + let writer = RelayConnWriter::Ws(writer); + + let local_addr = None; + + (reader, writer, local_addr) + } + Protocol::Relay => { + let tcp_stream = self.dial_url().await?; + + let local_addr = tcp_stream + .local_addr() + .map_err(|e| ClientError::NoLocalAddr(e.to_string()))?; + + debug!(server_addr = ?tcp_stream.peer_addr(), %local_addr, "TCP stream connected"); + + let response = if self.use_https() { + debug!("Starting TLS handshake"); + let hostname = self + .tls_servername() + .ok_or_else(|| ClientError::InvalidUrl("No tls servername".into()))?; + let tls_stream = self.tls_connector.connect(hostname, tcp_stream).await?; + debug!("tls_connector connect success"); + Self::start_upgrade(tls_stream, &PROTOCOL).await? + } else { + debug!("Starting handshake"); + Self::start_upgrade(tcp_stream, &PROTOCOL).await? + }; + + if response.status() != hyper::StatusCode::SWITCHING_PROTOCOLS { + error!( + "expected status 101 SWITCHING_PROTOCOLS, got: {}", + response.status() + ); + return Err(ClientError::UnexpectedStatusCode( + hyper::StatusCode::SWITCHING_PROTOCOLS, + response.status(), + )); + } + + debug!("starting upgrade"); + let upgraded = match hyper::upgrade::on(response).await { + Ok(upgraded) => upgraded, + Err(err) => { + warn!("upgrade failed: {:#}", err); + return Err(ClientError::Hyper(err)); + } + }; + + debug!("connection upgraded"); + let (reader, writer) = + downcast_upgrade(upgraded).map_err(|e| ClientError::Upgrade(e.to_string()))?; + + let reader = RelayConnReader::Relay(FramedRead::new(reader, DerpCodec)); + let writer = RelayConnWriter::Relay(FramedWrite::new(writer, DerpCodec)); + + (reader, writer, Some(local_addr)) + } + }; let (relay_client, receiver) = RelayClientBuilder::new(self.secret_key.clone(), local_addr, reader, writer) @@ -661,7 +675,10 @@ impl Actor { } /// Sends the HTTP upgrade request to the relay server. - async fn start_upgrade(io: T) -> Result, ClientError> + async fn start_upgrade( + io: T, + protocol: &Protocol, + ) -> Result, ClientError> where T: AsyncRead + AsyncWrite + Send + Unpin + 'static, { @@ -680,16 +697,24 @@ impl Actor { } .instrument(info_span!("http-driver")), ); + debug!("Sending upgrade request"); - let req = Request::builder() - .uri("/derp") - .header(UPGRADE, WEBSOCKET_UPGRADE_PROTOCOL) - .header( - "Sec-WebSocket-Key", - fastwebsockets::handshake::generate_key(), - ) - .header("Sec-WebSocket-Version", "13") - .body(http_body_util::Empty::::new())?; + let mut builder = Request::builder().uri("/derp"); + + match protocol { + Protocol::Websocket => { + builder = builder + .header(UPGRADE, protocol.upgrade_header()) + .header( + "Sec-WebSocket-Key", + fastwebsockets::handshake::generate_key(), + ) + .header("Sec-WebSocket-Version", "13"); + } + Protocol::Relay => builder = builder.header(UPGRADE, protocol.upgrade_header()), + } + + let req = builder.body(http_body_util::Empty::::new())?; request_sender.send_request(req).await.map_err(From::from) } diff --git a/iroh-net/src/relay/server.rs b/iroh-net/src/relay/server.rs index 48604f49fcc..875a086cebc 100644 --- a/iroh-net/src/relay/server.rs +++ b/iroh-net/src/relay/server.rs @@ -503,7 +503,7 @@ mod tests { use super::*; use crate::relay::{ - client::ClientBuilder, + client::{ClientBuilder, RelayConnReader, RelayConnWriter}, codec::{recv_frame, Frame, FrameType}, http::{ server::Protocol, @@ -655,12 +655,16 @@ mod tests { fn make_test_client(secret_key: SecretKey) -> (tokio::io::DuplexStream, ClientBuilder) { let (client, server) = tokio::io::duplex(10); let (client_reader, client_writer) = tokio::io::split(client); - let (client_reader, client_writer) = fastwebsockets::after_handshake_split( - MaybeTlsStreamReader::Mem(client_reader), - MaybeTlsStreamWriter::Mem(client_writer), - fastwebsockets::Role::Client, + let (client_reader, client_writer) = ( + RelayConnReader::Relay(FramedRead::new( + MaybeTlsStreamReader::Mem(client_reader), + DerpCodec, + )), + RelayConnWriter::Relay(FramedWrite::new( + MaybeTlsStreamWriter::Mem(client_writer), + DerpCodec, + )), ); - let (client_reader, client_writer) = todo!(); // TODO(matheus23) fix tests. Probably just use relay protocol here ( server, ClientBuilder::new(secret_key, None, client_reader, client_writer), @@ -682,7 +686,7 @@ mod tests { let handler = server.client_conn_handler(Default::default()); let handler_task = tokio::spawn(async move { handler - .accept(Protocol::Websocket, MaybeTlsStream::Test(rw_a)) + .accept(Protocol::Relay, MaybeTlsStream::Test(rw_a)) .await }); let (client_a, mut client_receiver_a) = client_a_builder.build().await?; @@ -695,7 +699,7 @@ mod tests { let handler = server.client_conn_handler(Default::default()); let handler_task = tokio::spawn(async move { handler - .accept(Protocol::Websocket, MaybeTlsStream::Test(rw_b)) + .accept(Protocol::Relay, MaybeTlsStream::Test(rw_b)) .await }); let (client_b, mut client_receiver_b) = client_b_builder.build().await?; @@ -758,7 +762,7 @@ mod tests { let handler = server.client_conn_handler(Default::default()); let handler_task = tokio::spawn(async move { handler - .accept(Protocol::Websocket, MaybeTlsStream::Test(rw_a)) + .accept(Protocol::Relay, MaybeTlsStream::Test(rw_a)) .await }); let (client_a, mut client_receiver_a) = client_a_builder.build().await?; @@ -771,7 +775,7 @@ mod tests { let handler = server.client_conn_handler(Default::default()); let handler_task = tokio::spawn(async move { handler - .accept(Protocol::Websocket, MaybeTlsStream::Test(rw_b)) + .accept(Protocol::Relay, MaybeTlsStream::Test(rw_b)) .await }); let (client_b, mut client_receiver_b) = client_b_builder.build().await?; @@ -808,7 +812,7 @@ mod tests { let handler = server.client_conn_handler(Default::default()); let handler_task = tokio::spawn(async move { handler - .accept(Protocol::Websocket, MaybeTlsStream::Test(new_rw_b)) + .accept(Protocol::Relay, MaybeTlsStream::Test(new_rw_b)) .await }); let (new_client_b, mut new_client_receiver_b) = new_client_b_builder.build().await?;