Skip to content

Commit

Permalink
refactor: Adjust mod exposure, rename & cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
matheus23 committed Jun 21, 2024
1 parent 8803fdd commit fdd1214
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 122 deletions.
44 changes: 17 additions & 27 deletions iroh-net/src/relay/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ impl ClientReceiver {

#[derive(derive_more::Debug)]
pub struct InnerClient {
// our local addrs
/// Our local address, if known.
///
/// `None` if we don't control the connection establishment, e.g. in browsers.
local_addr: Option<SocketAddr>,
/// Channel on which to communicate to the server. The associated [`mpsc::Receiver`] will close
/// if there is ever an error writing to the server.
Expand Down Expand Up @@ -191,7 +193,7 @@ fn process_incoming_frame(frame: Frame) -> Result<ReceivedMessage> {
}

/// The kinds of messages we can send to the [`super::server::Server`]
#[derive(derive_more::Debug)]
#[derive(Debug)]
enum ClientWriterMessage {
/// Send a packet (addressed to the [`PublicKey`]) to the server
Packet((PublicKey, Bytes)),
Expand All @@ -210,19 +212,18 @@ enum ClientWriterMessage {
///
/// Shutsdown when you send a [`ClientWriterMessage::Shutdown`], or if there is an error writing to
/// the server.
struct ClientWriter<W: Sink<Frame, Error = std::io::Error> + Unpin + 'static> {
struct ClientWriter {
recv_msgs: mpsc::Receiver<ClientWriterMessage>,
writer: W,
writer: RelayConnWriter,
rate_limiter: Option<RateLimiter>,
}

impl<W: Sink<Frame, Error = std::io::Error> + Unpin + 'static> ClientWriter<W> {
impl ClientWriter {
async fn run(mut self) -> Result<()> {
while let Some(msg) = self.recv_msgs.recv().await {
match msg {
ClientWriterMessage::Packet((key, bytes)) => {
send_packet(&mut self.writer, &self.rate_limiter, key, bytes).await?;
self.writer.flush().await?;
}
ClientWriterMessage::Pong(data) => {
write_frame(&mut self.writer, Frame::Pong { data }, None).await?;
Expand All @@ -246,33 +247,22 @@ impl<W: Sink<Frame, Error = std::io::Error> + Unpin + 'static> ClientWriter<W> {
}
}

/// The Builder returns a [`Client`] starts a [`ClientWriter`] run task.
/// The Builder returns a [`Client`] and a started [`ClientWriter`] run task.
pub struct ClientBuilder {
secret_key: SecretKey,
reader: RelayConnReader,
writer: RelayConnWriter,
local_addr: Option<SocketAddr>,
}

#[derive(derive_more::Debug)]
pub(crate) enum RelayConnReader {
Relay(
#[debug("FramedRead<MaybeTlsStreamReader, DerpCodec>")]
FramedRead<MaybeTlsStreamReader, DerpCodec>,
),
Ws(#[debug("SplitStream<WebSocketStream>")] SplitStream<WebSocketStream>),
Derp(FramedRead<MaybeTlsStreamReader, DerpCodec>),
Ws(SplitStream<WebSocketStream>),
}

#[derive(derive_more::Debug)]
pub(crate) enum RelayConnWriter {
Relay(
#[debug("FramedWrite<MaybeTlsStreamWriter, DerpCodec>")]
FramedWrite<MaybeTlsStreamWriter, DerpCodec>,
),
Ws(
#[debug("SplitSink<WebSocketStream, tokio_tungstenite_wasm::Message>")]
SplitSink<WebSocketStream, tokio_tungstenite_wasm::Message>,
),
Derp(FramedWrite<MaybeTlsStreamWriter, DerpCodec>),
Ws(SplitSink<WebSocketStream, tokio_tungstenite_wasm::Message>),
}

fn tung_wasm_to_io_err(e: tokio_tungstenite_wasm::Error) -> std::io::Error {
Expand All @@ -284,7 +274,7 @@ impl Stream for RelayConnReader {

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match *self {
Self::Relay(ref mut ws) => Pin::new(ws).poll_next(cx),
Self::Derp(ref mut ws) => Pin::new(ws).poll_next(cx),
Self::Ws(ref mut ws) => match Pin::new(ws).poll_next(cx) {
Poll::Ready(Some(item)) => match Frame::from_wasm_ws_message(item) {
Some(frame) => Poll::Ready(Some(frame)),
Expand All @@ -303,14 +293,14 @@ impl Sink<Frame> for RelayConnWriter {

fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match *self {
Self::Relay(ref mut ws) => Pin::new(ws).poll_ready(cx),
Self::Derp(ref mut ws) => Pin::new(ws).poll_ready(cx),
Self::Ws(ref mut ws) => Pin::new(ws).poll_ready(cx).map_err(tung_wasm_to_io_err),
}
}

fn start_send(mut self: Pin<&mut Self>, item: Frame) -> Result<(), Self::Error> {
match *self {
Self::Relay(ref mut ws) => Pin::new(ws).start_send(item),
Self::Derp(ref mut ws) => Pin::new(ws).start_send(item),
Self::Ws(ref mut ws) => Pin::new(ws)
.start_send(item.into_wasm_ws_message()?)
.map_err(tung_wasm_to_io_err),
Expand All @@ -319,14 +309,14 @@ impl Sink<Frame> for RelayConnWriter {

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match *self {
Self::Relay(ref mut ws) => Pin::new(ws).poll_flush(cx),
Self::Derp(ref mut ws) => Pin::new(ws).poll_flush(cx),
Self::Ws(ref mut ws) => Pin::new(ws).poll_flush(cx).map_err(tung_wasm_to_io_err),
}
}

fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match *self {
Self::Relay(ref mut ws) => Pin::new(ws).poll_close(cx),
Self::Derp(ref mut ws) => Pin::new(ws).poll_close(cx),
Self::Ws(ref mut ws) => Pin::new(ws).poll_close(cx).map_err(tung_wasm_to_io_err),
}
}
Expand Down
4 changes: 2 additions & 2 deletions iroh-net/src/relay/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
//! upgrades.
//!
mod client;
pub(crate) mod server;
mod server;
pub(crate) mod streams;

pub use self::client::{Client, ClientBuilder, ClientError, ClientReceiver};
pub use self::server::{Server, ServerBuilder, ServerHandle, TlsAcceptor, TlsConfig};
pub use self::server::{Protocol, Server, ServerBuilder, ServerHandle, TlsAcceptor, TlsConfig};

pub(crate) const HTTP_UPGRADE_PROTOCOL: &str = "iroh derp http";
pub(crate) const WEBSOCKET_UPGRADE_PROTOCOL: &str = "websocket";
Expand Down
158 changes: 77 additions & 81 deletions iroh-net/src/relay/http/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ pub enum ClientError {
/// The inner actor is gone, likely means things are shutdown.
#[error("actor gone")]
ActorGone,
/// There was an error related to websockets
/// An error related to websockets, either errors with parsing ws messages or the handshake
#[error("websocket error: {0}")]
WebsocketError(#[from] tokio_tungstenite_wasm::Error),
}
Expand Down Expand Up @@ -588,73 +588,17 @@ impl Actor {
}

async fn connect_0(&self) -> Result<(RelayClient, RelayClientReceiver), ClientError> {
// We determine which protocol to use for relays via the URL scheme: ws(s) vs. http(s)
let protocol = Protocol::from_url_scheme(&self.url);

let (reader, writer, local_addr) = match &protocol {
Protocol::Websocket => {
let mut dial_url = (*self.url).clone();
dial_url.set_path("/derp");

debug!(%dial_url, "Dialing relay by websocket");

let (writer, reader) = tokio_tungstenite_wasm::connect(dial_url).await?.split();

let reader = RelayConnReader::Ws(reader);
let writer = RelayConnWriter::Ws(writer);

let (reader, writer) = self.connect_ws().await?;
let local_addr = None;

(reader, writer, local_addr)
}
Protocol::Relay => {
let tcp_stream = self.dial_url().await?;

let local_addr = tcp_stream
.local_addr()
.map_err(|e| ClientError::NoLocalAddr(e.to_string()))?;

debug!(server_addr = ?tcp_stream.peer_addr(), %local_addr, "TCP stream connected");

let response = if self.use_https() {
debug!("Starting TLS handshake");
let hostname = self
.tls_servername()
.ok_or_else(|| ClientError::InvalidUrl("No tls servername".into()))?;
let tls_stream = self.tls_connector.connect(hostname, tcp_stream).await?;
debug!("tls_connector connect success");
Self::start_upgrade(tls_stream, &protocol).await?
} else {
debug!("Starting handshake");
Self::start_upgrade(tcp_stream, &protocol).await?
};

if response.status() != hyper::StatusCode::SWITCHING_PROTOCOLS {
error!(
"expected status 101 SWITCHING_PROTOCOLS, got: {}",
response.status()
);
return Err(ClientError::UnexpectedStatusCode(
hyper::StatusCode::SWITCHING_PROTOCOLS,
response.status(),
));
}

debug!("starting upgrade");
let upgraded = match hyper::upgrade::on(response).await {
Ok(upgraded) => upgraded,
Err(err) => {
warn!("upgrade failed: {:#}", err);
return Err(ClientError::Hyper(err));
}
};

debug!("connection upgraded");
let (reader, writer) =
downcast_upgrade(upgraded).map_err(|e| ClientError::Upgrade(e.to_string()))?;

let reader = RelayConnReader::Relay(FramedRead::new(reader, DerpCodec));
let writer = RelayConnWriter::Relay(FramedWrite::new(writer, DerpCodec));

let (reader, writer, local_addr) = self.connect_derp().await?;
(reader, writer, Some(local_addr))
}
};
Expand All @@ -674,11 +618,76 @@ impl Actor {
Ok((relay_client, receiver))
}

async fn connect_ws(&self) -> Result<(RelayConnReader, RelayConnWriter), ClientError> {
let mut dial_url = (*self.url).clone();
dial_url.set_path("/derp");

debug!(%dial_url, "Dialing relay by websocket");

let (writer, reader) = tokio_tungstenite_wasm::connect(dial_url).await?.split();

let reader = RelayConnReader::Ws(reader);
let writer = RelayConnWriter::Ws(writer);

Ok((reader, writer))
}

async fn connect_derp(
&self,
) -> Result<(RelayConnReader, RelayConnWriter, SocketAddr), ClientError> {
let tcp_stream = self.dial_url().await?;

let local_addr = tcp_stream
.local_addr()
.map_err(|e| ClientError::NoLocalAddr(e.to_string()))?;

debug!(server_addr = ?tcp_stream.peer_addr(), %local_addr, "TCP stream connected");

let response = if self.use_https() {
debug!("Starting TLS handshake");
let hostname = self
.tls_servername()
.ok_or_else(|| ClientError::InvalidUrl("No tls servername".into()))?;
let tls_stream = self.tls_connector.connect(hostname, tcp_stream).await?;
debug!("tls_connector connect success");
Self::start_upgrade(tls_stream).await?
} else {
debug!("Starting handshake");
Self::start_upgrade(tcp_stream).await?
};

if response.status() != hyper::StatusCode::SWITCHING_PROTOCOLS {
error!(
"expected status 101 SWITCHING_PROTOCOLS, got: {}",
response.status()
);
return Err(ClientError::UnexpectedStatusCode(
hyper::StatusCode::SWITCHING_PROTOCOLS,
response.status(),
));
}

debug!("starting upgrade");
let upgraded = match hyper::upgrade::on(response).await {
Ok(upgraded) => upgraded,
Err(err) => {
warn!("upgrade failed: {:#}", err);
return Err(ClientError::Hyper(err));
}
};

debug!("connection upgraded");
let (reader, writer) =
downcast_upgrade(upgraded).map_err(|e| ClientError::Upgrade(e.to_string()))?;

let reader = RelayConnReader::Derp(FramedRead::new(reader, DerpCodec));
let writer = RelayConnWriter::Derp(FramedWrite::new(writer, DerpCodec));

Ok((reader, writer, local_addr))
}

/// Sends the HTTP upgrade request to the relay server.
async fn start_upgrade<T>(
io: T,
protocol: &Protocol,
) -> Result<hyper::Response<Incoming>, ClientError>
async fn start_upgrade<T>(io: T) -> Result<hyper::Response<Incoming>, ClientError>
where
T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
Expand All @@ -697,24 +706,11 @@ impl Actor {
}
.instrument(info_span!("http-driver")),
);

debug!("Sending upgrade request");
let mut builder = Request::builder().uri("/derp");

match protocol {
Protocol::Websocket => {
builder = builder
.header(UPGRADE, protocol.upgrade_header())
.header(
"Sec-WebSocket-Key",
tungstenite::handshake::client::generate_key(),
)
.header("Sec-WebSocket-Version", "13");
}
Protocol::Relay => builder = builder.header(UPGRADE, protocol.upgrade_header()),
}

let req = builder.body(http_body_util::Empty::<hyper::body::Bytes>::new())?;
let req = Request::builder()
.uri("/derp")
.header(UPGRADE, Protocol::Relay.upgrade_header())
.body(http_body_util::Empty::<hyper::body::Bytes>::new())?;
request_sender.send_request(req).await.map_err(From::from)
}

Expand Down
21 changes: 9 additions & 12 deletions iroh-net/src/relay/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use std::task::{Context, Poll};
use std::time::Duration;

use anyhow::{bail, Context as _, Result};
use futures_lite::Stream;
use futures_sink::Sink;
use futures_util::Stream;
use hyper::HeaderMap;
use iroh_metrics::core::UsageStatsReport;
use iroh_metrics::{inc, report_usage_stats};
Expand All @@ -23,7 +23,7 @@ use tungstenite::protocol::Role;
use crate::key::{PublicKey, SecretKey};

use super::codec::Frame;
use super::http::server::Protocol;
use super::http::Protocol;
use super::{
client_conn::ClientConnBuilder,
clients::Clients,
Expand Down Expand Up @@ -657,16 +657,13 @@ mod tests {
fn make_test_client(secret_key: SecretKey) -> (tokio::io::DuplexStream, ClientBuilder) {
let (client, server) = tokio::io::duplex(10);
let (client_reader, client_writer) = tokio::io::split(client);
let (client_reader, client_writer) = (
RelayConnReader::Relay(FramedRead::new(
MaybeTlsStreamReader::Mem(client_reader),
DerpCodec,
)),
RelayConnWriter::Relay(FramedWrite::new(
MaybeTlsStreamWriter::Mem(client_writer),
DerpCodec,
)),
);

let client_reader = MaybeTlsStreamReader::Mem(client_reader);
let client_writer = MaybeTlsStreamWriter::Mem(client_writer);

let client_reader = RelayConnReader::Derp(FramedRead::new(client_reader, DerpCodec));
let client_writer = RelayConnWriter::Derp(FramedWrite::new(client_writer, DerpCodec));

(
server,
ClientBuilder::new(secret_key, None, client_reader, client_writer),
Expand Down

0 comments on commit fdd1214

Please sign in to comment.