Skip to content

Commit

Permalink
start implementing proxy support
Browse files Browse the repository at this point in the history
  • Loading branch information
dignifiedquire committed May 15, 2024
1 parent 9d71fd8 commit 5fd4f58
Showing 1 changed file with 205 additions and 26 deletions.
231 changes: 205 additions & 26 deletions iroh-net/src/relay/http/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@
use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr};
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;

use anyhow::bail;
use bytes::Bytes;
use futures_lite::future::Boxed as BoxFuture;
use http_body_util::Empty;
use hyper::body::Incoming;
use hyper::header::UPGRADE;
use hyper::upgrade::{Parts, Upgraded};
use hyper::Request;
use hyper_util::rt::TokioIo;
use rand::Rng;
use rustls::client::Resumption;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
Expand Down Expand Up @@ -159,6 +162,7 @@ struct Actor {
pings: PingTracker,
ping_tasks: JoinSet<()>,
dns_resolver: DnsResolver,
proxy_url: Option<Url>,
}

#[derive(Default, Debug)]
Expand Down Expand Up @@ -200,6 +204,8 @@ pub struct ClientBuilder {
/// Allow self-signed certificates from relay servers
#[cfg(any(test, feature = "test-utils"))]
insecure_skip_cert_verify: bool,
/// HTTP Proxy
proxy_url: Option<Url>,
}

impl std::fmt::Debug for ClientBuilder {
Expand All @@ -224,6 +230,7 @@ impl ClientBuilder {
url: url.into(),
#[cfg(any(test, feature = "test-utils"))]
insecure_skip_cert_verify: false,
proxy_url: None,
}
}

Expand Down Expand Up @@ -275,6 +282,12 @@ impl ClientBuilder {
self
}

/// Set a proxy url to proxy all HTTP(S) traffic through.
pub fn proxy_url(mut self, url: Url) -> Self {
self.proxy_url.replace(url);
self
}

/// Build the [`Client`]
pub fn build(self, key: SecretKey, dns_resolver: DnsResolver) -> (Client, ClientReceiver) {
// TODO: review TLS config
Expand Down Expand Up @@ -316,6 +329,7 @@ impl ClientBuilder {
url: self.url,
tls_connector,
dns_resolver,
proxy_url: self.proxy_url,
};

let (msg_sender, inbox) = mpsc::channel(64);
Expand Down Expand Up @@ -762,18 +776,6 @@ impl Actor {
.and_then(|s| rustls::ServerName::try_from(s).ok())
}

fn url_port(&self) -> Option<u16> {
if let Some(port) = self.url.port() {
return Some(port);
}

match self.url.scheme() {
"http" => Some(80),
"https" => Some(443),
_ => None,
}
}

fn use_https(&self) -> bool {
// only disable https if we are explicitly dialing a http url
if self.url.scheme() == "http" {
Expand All @@ -782,14 +784,26 @@ impl Actor {
true
}

async fn dial_url(&self) -> Result<TcpStream, ClientError> {
debug!(%self.url, "dial url");
async fn dial_url(&self) -> Result<ProxyStream, ClientError> {
if let Some(ref proxy) = self.proxy_url {
let (stream, local_addr, peer_addr) = self.dial_url_proxy(proxy.clone()).await?;
Ok(ProxyStream::Proxied {
stream,
local_addr,
peer_addr,
})
} else {
let stream = self.dial_url_direct().await?;
Ok(ProxyStream::Raw(stream))
}
}

async fn dial_url_direct(&self) -> Result<TcpStream, ClientError> {
debug!(%self.url, "dial url");
let prefer_ipv6 = self.prefer_ipv6().await;
let dst_ip = resolve_host(&self.dns_resolver, &self.url, prefer_ipv6).await?;

let port = self
.url_port()
let port = url_port(&self.url)
.ok_or_else(|| ClientError::InvalidUrl("missing url port".into()))?;
let addr = SocketAddr::new(dst_ip, port);

Expand All @@ -808,6 +822,61 @@ impl Actor {
Ok(tcp_stream)
}

async fn dial_url_proxy(
&self,
proxy_url: Url,
) -> Result<(TokioIo<hyper::upgrade::Upgraded>, SocketAddr, SocketAddr), ClientError> {
debug!(%self.url, %proxy_url, "dial url via proxy");

// Resolve proxy DNS
let prefer_ipv6 = self.prefer_ipv6().await;
let proxy_ip = resolve_host(&self.dns_resolver, &proxy_url, prefer_ipv6).await?;

let proxy_port = url_port(&proxy_url)
.ok_or_else(|| ClientError::InvalidUrl("missing proxy url port".into()))?;
let proxy_addr = SocketAddr::new(proxy_ip, proxy_port);

debug!(%proxy_addr, "connecting to proxy");

// TODO: add TLS support

let tcp_stream = tokio::time::timeout(DIAL_NODE_TIMEOUT, async move {
TcpStream::connect(proxy_addr).await
})
.await
.map_err(|_| ClientError::ConnectTimeout)?
.map_err(ClientError::DialIO)?;

tcp_stream.set_nodelay(true)?;

let local_addr = tcp_stream.local_addr()?;
let peer_addr = tcp_stream.peer_addr()?;

// Establish Proxy Tunnel
let req = Request::builder()
.uri(self.url.to_string())
.method("CONNECT")
.body(Empty::<Bytes>::new())?;

let io = TokioIo::new(tcp_stream);
let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?;
tokio::task::spawn(async move {
if let Err(err) = conn.with_upgrades().await {
println!("Connection failed: {:?}", err);
}
});

let res = sender.send_request(req).await?;
if !res.status().is_success() {
panic!("Our server didn't CONNECT: {}", res.status());
}

let upgraded = hyper::upgrade::on(res).await?;
let tunnel = TokioIo::new(upgraded);

Ok((tunnel, local_addr, peer_addr))
}

/// Reports whether IPv4 dials should be slightly
/// delayed to give IPv6 a better chance of winning dial races.
/// Implementations should only return true if IPv6 is expected
Expand Down Expand Up @@ -887,23 +956,40 @@ fn downcast_upgrade(
Box<dyn AsyncRead + Unpin + Send + Sync + 'static>,
Box<dyn AsyncWrite + Unpin + Send + Sync + 'static>,
)> {
match upgraded.downcast::<hyper_util::rt::TokioIo<tokio::net::TcpStream>>() {
match upgraded.downcast::<TokioIo<ProxyStream>>() {
Ok(Parts { read_buf, io, .. }) => {
let (reader, writer) = tokio::io::split(io.into_inner());
// Prepend data to the reader to avoid data loss
let reader = std::io::Cursor::new(read_buf).chain(reader);

Ok((Box::new(reader), Box::new(writer)))
let inner = io.into_inner();
match inner {
ProxyStream::Raw(tcp_stream) => {
let (reader, writer) = tokio::io::split(tcp_stream);
// Prepend data to the reader to avoid data loss
let reader = std::io::Cursor::new(read_buf).chain(reader);
Ok((Box::new(reader), Box::new(writer)))
}
ProxyStream::Proxied { stream, .. } => match stream.into_inner().downcast::<TokioIo<TcpStream>>() {
Ok(Parts { read_buf: read_buf_inner, io, .. }) => {
let (reader, writer) = tokio::io::split(io.into_inner());
// Prepend data to the reader to avoid data loss
let reader = std::io::Cursor::new(read_buf_inner).chain(std::io::Cursor::new(read_buf)).chain(reader);
Ok((Box::new(reader), Box::new(writer)))
}
Err(_) => {
bail!("could not downcast");
}
}
}
}
Err(upgraded) => {
if let Ok(Parts { read_buf, io, .. }) =
upgraded.downcast::<hyper_util::rt::TokioIo<tokio_rustls::client::TlsStream<tokio::net::TcpStream>>>()
if let Ok(Parts { read_buf, io, .. }) = upgraded
.downcast::<TokioIo<tokio_rustls::client::TlsStream<ProxyStream>>>()
{
let (reader, writer) = tokio::io::split(io.into_inner());
let inner = io.into_inner();
let (reader, writer) = tokio::io::split(inner);
// Prepend data to the reader to avoid data loss
let reader = std::io::Cursor::new(read_buf).chain(reader);

return Ok((Box::new(reader), Box::new(writer)));
todo!()
// return Ok((Box::new(reader), Box::new(writer)));
}

bail!(
Expand Down Expand Up @@ -932,6 +1018,99 @@ impl rustls::client::ServerCertVerifier for NoCertVerifier {
}
}

fn url_port(url: &Url) -> Option<u16> {
if let Some(port) = url.port() {
return Some(port);
}

match url.scheme() {
"http" => Some(80),
"https" => Some(443),
_ => None,
}
}

enum ProxyStream {
Raw(TcpStream),
Proxied {
stream: TokioIo<hyper::upgrade::Upgraded>,
local_addr: SocketAddr,
peer_addr: SocketAddr,
},
}

impl AsyncRead for ProxyStream {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
match &mut *self {
Self::Raw(stream) => Pin::new(stream).poll_read(cx, buf),
Self::Proxied { stream, .. } => Pin::new(stream).poll_read(cx, buf),
}
}
}

impl AsyncWrite for ProxyStream {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<Result<usize, std::io::Error>> {
match &mut *self {
Self::Raw(stream) => Pin::new(stream).poll_write(cx, buf),
Self::Proxied { stream, .. } => Pin::new(stream).poll_write(cx, buf),
}
}

fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
match &mut *self {
Self::Raw(stream) => Pin::new(stream).poll_flush(cx),
Self::Proxied { stream, .. } => Pin::new(stream).poll_flush(cx),
}
}

fn poll_shutdown(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
match &mut *self {
Self::Raw(stream) => Pin::new(stream).poll_shutdown(cx),
Self::Proxied { stream, .. } => Pin::new(stream).poll_shutdown(cx),
}
}
fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> std::task::Poll<Result<usize, std::io::Error>> {
match &mut *self {
Self::Raw(stream) => Pin::new(stream).poll_write_vectored(cx, bufs),
Self::Proxied { stream, .. } => Pin::new(stream).poll_write_vectored(cx, bufs),
}
}
}

impl ProxyStream {
fn local_addr(&self) -> std::io::Result<SocketAddr> {
match self {
Self::Raw(s) => s.local_addr(),
Self::Proxied { local_addr, .. } => Ok(*local_addr),
}
}

fn peer_addr(&self) -> std::io::Result<SocketAddr> {
match self {
Self::Raw(s) => s.peer_addr(),
Self::Proxied { peer_addr, .. } => Ok(*peer_addr),
}
}
}

#[cfg(test)]
mod tests {
use anyhow::Result;
Expand Down

0 comments on commit 5fd4f58

Please sign in to comment.