Skip to content

Commit

Permalink
feat: Add back codepath for derp protocol on the client
Browse files Browse the repository at this point in the history
  • Loading branch information
matheus23 committed Jun 20, 2024
1 parent 36b8235 commit d2aca66
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 81 deletions.
21 changes: 17 additions & 4 deletions iroh-net/src/relay/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,19 +226,19 @@ impl<W: Sink<Frame, Error = std::io::Error> + Unpin + 'static> ClientWriter<W> {
match msg {
ClientWriterMessage::Packet((key, bytes)) => {
send_packet(&mut self.writer, &self.rate_limiter, key, bytes).await?;
self.writer.flush();
self.writer.flush().await?;
}
ClientWriterMessage::Pong(data) => {
write_frame(&mut self.writer, Frame::Pong { data }, None).await?;
self.writer.flush();
self.writer.flush().await?;
}
ClientWriterMessage::Ping(data) => {
write_frame(&mut self.writer, Frame::Ping { data }, None).await?;
self.writer.flush();
self.writer.flush().await?;
}
ClientWriterMessage::NotePreferred(preferred) => {
write_frame(&mut self.writer, Frame::NotePreferred { preferred }, None).await?;
self.writer.flush();
self.writer.flush().await?;
}
ClientWriterMessage::Shutdown => {
return Ok(());
Expand All @@ -260,11 +260,19 @@ pub struct ClientBuilder {

#[derive(derive_more::Debug)]
pub(crate) enum RelayConnReader {
Relay(
#[debug("FramedRead<MaybeTlsStreamReader, DerpCodec>")]
FramedRead<MaybeTlsStreamReader, DerpCodec>,
),
Ws(#[debug("SplitStream<WebSocketStream>")] SplitStream<WebSocketStream>),
}

#[derive(derive_more::Debug)]
pub(crate) enum RelayConnWriter {
Relay(
#[debug("FramedWrite<MaybeTlsStreamWriter, DerpCodec>")]
FramedWrite<MaybeTlsStreamWriter, DerpCodec>,
),
Ws(
#[debug("SplitSink<WebSocketStream, tokio_tungstenite_wasm::Message>")]
SplitSink<WebSocketStream, tokio_tungstenite_wasm::Message>,
Expand All @@ -280,6 +288,7 @@ impl Stream for RelayConnReader {

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match *self {
Self::Relay(ref mut ws) => ws.poll_next_unpin(cx),
Self::Ws(ref mut ws) => ws.poll_next_unpin(cx).map(Frame::from_wasm_ws_message),
}
}
Expand All @@ -290,12 +299,14 @@ impl Sink<Frame> for RelayConnWriter {

fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match *self {
Self::Relay(ref mut ws) => ws.poll_ready_unpin(cx),
Self::Ws(ref mut ws) => ws.poll_ready_unpin(cx).map_err(tung_to_io_err),
}
}

fn start_send(mut self: Pin<&mut Self>, item: Frame) -> Result<(), Self::Error> {
match *self {
Self::Relay(ref mut ws) => ws.start_send_unpin(item),
Self::Ws(ref mut ws) => ws
.start_send_unpin(item.into_wasm_ws_message()?)
.map_err(tung_to_io_err),
Expand All @@ -304,12 +315,14 @@ impl Sink<Frame> for RelayConnWriter {

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match *self {
Self::Relay(ref mut ws) => ws.poll_flush_unpin(cx),
Self::Ws(ref mut ws) => ws.poll_flush_unpin(cx).map_err(tung_to_io_err),
}
}

fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match *self {
Self::Relay(ref mut ws) => ws.poll_close_unpin(cx),
Self::Ws(ref mut ws) => ws.poll_close_unpin(cx).map_err(tung_to_io_err),
}
}
Expand Down
157 changes: 91 additions & 66 deletions iroh-net/src/relay/http/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,14 @@ use tokio::sync::{mpsc, oneshot};
use tokio::task::JoinSet;
use tokio::time::Instant;
use tokio_tungstenite_wasm::WebSocketStream;
use tokio_util::codec::{FramedRead, FramedWrite};
use tracing::{debug, error, info_span, trace, warn, Instrument};
use url::Url;

use crate::dns::{DnsResolver, ResolverExt};
use crate::key::{PublicKey, SecretKey};
use crate::relay::client::{RelayConnReader, RelayConnWriter};
use crate::relay::codec::DerpCodec;
use crate::relay::http::streams::{downcast_upgrade, MaybeTlsStream};
use crate::relay::http::WEBSOCKET_UPGRADE_PROTOCOL;
use crate::relay::RelayUrl;
Expand All @@ -40,6 +42,7 @@ use crate::relay::{
use crate::util::chain;
use crate::util::AbortingJoinHandle;

use super::server::Protocol;
use super::streams::ProxyStream;

const DIAL_NODE_TIMEOUT: Duration = Duration::from_millis(1500);
Expand Down Expand Up @@ -588,62 +591,73 @@ impl Actor {
}

async fn connect_0(&self) -> Result<(RelayClient, RelayClientReceiver), ClientError> {
// let tcp_stream = self.dial_url().await?;

// let local_addr = tcp_stream
// .local_addr()
// .map_err(|e| ClientError::NoLocalAddr(e.to_string()))?;

// debug!(server_addr = ?tcp_stream.peer_addr(), %local_addr, "TCP stream connected");

// let response = if self.use_https() {
// debug!("Starting TLS handshake");
// let hostname = self
// .tls_servername()
// .ok_or_else(|| ClientError::InvalidUrl("No tls servername".into()))?;
// let tls_stream = self.tls_connector.connect(hostname, tcp_stream).await?;
// debug!("tls_connector connect success");
// Self::start_upgrade(tls_stream).await?
// } else {
// debug!("Starting handshake");
// Self::start_upgrade(tcp_stream).await?
// };

// if response.status() != hyper::StatusCode::SWITCHING_PROTOCOLS {
// error!(
// "expected status 101 SWITCHING_PROTOCOLS, got: {}",
// response.status()
// );
// return Err(ClientError::UnexpectedStatusCode(
// hyper::StatusCode::SWITCHING_PROTOCOLS,
// response.status(),
// ));
// }

// debug!("starting upgrade");
// let upgraded = match hyper::upgrade::on(response).await {
// Ok(upgraded) => upgraded,
// Err(err) => {
// warn!("upgrade failed: {:#}", err);
// return Err(ClientError::Hyper(err));
// }
// };

// debug!("connection upgraded");
// let (reader, writer) =
// downcast_upgrade(upgraded).map_err(|e| ClientError::Upgrade(e.to_string()))?;

// let (reader, writer) =
// fastwebsockets::after_handshake_split(reader, writer, fastwebsockets::Role::Client);

let (writer, reader) = tokio_tungstenite_wasm::connect(self.url.as_str())
.await?
.split();

let reader = RelayConnReader::Ws(reader);
let writer = RelayConnWriter::Ws(writer);

let local_addr = None;
const PROTOCOL: Protocol = Protocol::Relay;

let (reader, writer, local_addr) = match &PROTOCOL {
Protocol::Websocket => {
let (writer, reader) = tokio_tungstenite_wasm::connect(self.url.as_str())
.await?
.split();

let reader = RelayConnReader::Ws(reader);
let writer = RelayConnWriter::Ws(writer);

let local_addr = None;

(reader, writer, local_addr)
}
Protocol::Relay => {
let tcp_stream = self.dial_url().await?;

let local_addr = tcp_stream
.local_addr()
.map_err(|e| ClientError::NoLocalAddr(e.to_string()))?;

debug!(server_addr = ?tcp_stream.peer_addr(), %local_addr, "TCP stream connected");

let response = if self.use_https() {
debug!("Starting TLS handshake");
let hostname = self
.tls_servername()
.ok_or_else(|| ClientError::InvalidUrl("No tls servername".into()))?;
let tls_stream = self.tls_connector.connect(hostname, tcp_stream).await?;
debug!("tls_connector connect success");
Self::start_upgrade(tls_stream, &PROTOCOL).await?
} else {
debug!("Starting handshake");
Self::start_upgrade(tcp_stream, &PROTOCOL).await?
};

if response.status() != hyper::StatusCode::SWITCHING_PROTOCOLS {
error!(
"expected status 101 SWITCHING_PROTOCOLS, got: {}",
response.status()
);
return Err(ClientError::UnexpectedStatusCode(
hyper::StatusCode::SWITCHING_PROTOCOLS,
response.status(),
));
}

debug!("starting upgrade");
let upgraded = match hyper::upgrade::on(response).await {
Ok(upgraded) => upgraded,
Err(err) => {
warn!("upgrade failed: {:#}", err);
return Err(ClientError::Hyper(err));
}
};

debug!("connection upgraded");
let (reader, writer) =
downcast_upgrade(upgraded).map_err(|e| ClientError::Upgrade(e.to_string()))?;

let reader = RelayConnReader::Relay(FramedRead::new(reader, DerpCodec));
let writer = RelayConnWriter::Relay(FramedWrite::new(writer, DerpCodec));

(reader, writer, Some(local_addr))
}
};

let (relay_client, receiver) =
RelayClientBuilder::new(self.secret_key.clone(), local_addr, reader, writer)
Expand All @@ -661,7 +675,10 @@ impl Actor {
}

/// Sends the HTTP upgrade request to the relay server.
async fn start_upgrade<T>(io: T) -> Result<hyper::Response<Incoming>, ClientError>
async fn start_upgrade<T>(
io: T,
protocol: &Protocol,
) -> Result<hyper::Response<Incoming>, ClientError>
where
T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
Expand All @@ -680,16 +697,24 @@ impl Actor {
}
.instrument(info_span!("http-driver")),
);

debug!("Sending upgrade request");
let req = Request::builder()
.uri("/derp")
.header(UPGRADE, WEBSOCKET_UPGRADE_PROTOCOL)
.header(
"Sec-WebSocket-Key",
fastwebsockets::handshake::generate_key(),
)
.header("Sec-WebSocket-Version", "13")
.body(http_body_util::Empty::<hyper::body::Bytes>::new())?;
let mut builder = Request::builder().uri("/derp");

match protocol {
Protocol::Websocket => {
builder = builder
.header(UPGRADE, protocol.upgrade_header())
.header(
"Sec-WebSocket-Key",
fastwebsockets::handshake::generate_key(),
)
.header("Sec-WebSocket-Version", "13");
}
Protocol::Relay => builder = builder.header(UPGRADE, protocol.upgrade_header()),
}

let req = builder.body(http_body_util::Empty::<hyper::body::Bytes>::new())?;
request_sender.send_request(req).await.map_err(From::from)
}

Expand Down
26 changes: 15 additions & 11 deletions iroh-net/src/relay/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ mod tests {
use super::*;

use crate::relay::{
client::ClientBuilder,
client::{ClientBuilder, RelayConnReader, RelayConnWriter},
codec::{recv_frame, Frame, FrameType},
http::{
server::Protocol,
Expand Down Expand Up @@ -655,12 +655,16 @@ mod tests {
fn make_test_client(secret_key: SecretKey) -> (tokio::io::DuplexStream, ClientBuilder) {
let (client, server) = tokio::io::duplex(10);
let (client_reader, client_writer) = tokio::io::split(client);
let (client_reader, client_writer) = fastwebsockets::after_handshake_split(
MaybeTlsStreamReader::Mem(client_reader),
MaybeTlsStreamWriter::Mem(client_writer),
fastwebsockets::Role::Client,
let (client_reader, client_writer) = (
RelayConnReader::Relay(FramedRead::new(
MaybeTlsStreamReader::Mem(client_reader),
DerpCodec,
)),
RelayConnWriter::Relay(FramedWrite::new(
MaybeTlsStreamWriter::Mem(client_writer),
DerpCodec,
)),
);
let (client_reader, client_writer) = todo!(); // TODO(matheus23) fix tests. Probably just use relay protocol here
(
server,
ClientBuilder::new(secret_key, None, client_reader, client_writer),
Expand All @@ -682,7 +686,7 @@ mod tests {
let handler = server.client_conn_handler(Default::default());
let handler_task = tokio::spawn(async move {
handler
.accept(Protocol::Websocket, MaybeTlsStream::Test(rw_a))
.accept(Protocol::Relay, MaybeTlsStream::Test(rw_a))
.await
});
let (client_a, mut client_receiver_a) = client_a_builder.build().await?;
Expand All @@ -695,7 +699,7 @@ mod tests {
let handler = server.client_conn_handler(Default::default());
let handler_task = tokio::spawn(async move {
handler
.accept(Protocol::Websocket, MaybeTlsStream::Test(rw_b))
.accept(Protocol::Relay, MaybeTlsStream::Test(rw_b))
.await
});
let (client_b, mut client_receiver_b) = client_b_builder.build().await?;
Expand Down Expand Up @@ -758,7 +762,7 @@ mod tests {
let handler = server.client_conn_handler(Default::default());
let handler_task = tokio::spawn(async move {
handler
.accept(Protocol::Websocket, MaybeTlsStream::Test(rw_a))
.accept(Protocol::Relay, MaybeTlsStream::Test(rw_a))
.await
});
let (client_a, mut client_receiver_a) = client_a_builder.build().await?;
Expand All @@ -771,7 +775,7 @@ mod tests {
let handler = server.client_conn_handler(Default::default());
let handler_task = tokio::spawn(async move {
handler
.accept(Protocol::Websocket, MaybeTlsStream::Test(rw_b))
.accept(Protocol::Relay, MaybeTlsStream::Test(rw_b))
.await
});
let (client_b, mut client_receiver_b) = client_b_builder.build().await?;
Expand Down Expand Up @@ -808,7 +812,7 @@ mod tests {
let handler = server.client_conn_handler(Default::default());
let handler_task = tokio::spawn(async move {
handler
.accept(Protocol::Websocket, MaybeTlsStream::Test(new_rw_b))
.accept(Protocol::Relay, MaybeTlsStream::Test(new_rw_b))
.await
});
let (new_client_b, mut new_client_receiver_b) = new_client_b_builder.build().await?;
Expand Down

0 comments on commit d2aca66

Please sign in to comment.