Skip to content

Commit

Permalink
Implemented quic transport
Browse files Browse the repository at this point in the history
Signed-off-by: Marlon Baeten <[email protected]>
  • Loading branch information
marlonbaeten committed Jun 11, 2024
1 parent 332aaaa commit bc34982
Show file tree
Hide file tree
Showing 12 changed files with 318 additions and 76 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ tokio-rustls = "0.26"
rustls-pki-types = "1.7"
rustls-native-certs = "0.7"
rustls-pemfile = "2.1"
quinn = "0.11"
# resolve
reqwest = { version = "0.12.3", default-features = false, features = ["rustls-tls-native-roots", "json", "stream", "charset", "http2", "macos-system-configuration"] }
# serialize
Expand Down
4 changes: 2 additions & 2 deletions examples/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,9 @@ async fn read_database(

trace!("opened database {database_name}");

vault.destroy().await?;
// vault.destroy().await?;

let vault = Vault::new_sqlite(database_name, password.as_bytes()).await?;
// let vault = Vault::new_sqlite(database_name, password.as_bytes()).await?;

Ok((vault, db, aliases))
}
Expand Down
5 changes: 3 additions & 2 deletions tsp/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ async = [
"dep:tokio-rustls",
"dep:rustls-pki-types",
"dep:rustls-native-certs",
"dep:rustls-pemfile"
"dep:rustls-pemfile",
"dep:quinn"
]
resolve = [
"serialize",
Expand Down Expand Up @@ -67,6 +68,7 @@ tokio-rustls = { workspace = true, optional = true }
rustls-pki-types = { workspace = true, optional = true }
rustls-native-certs = { workspace = true, optional = true }
rustls-pemfile = { workspace = true, optional = true }
quinn = { workspace = true, optional = true }
# resolve
reqwest = { workspace = true, optional = true }
# serialize
Expand All @@ -79,4 +81,3 @@ arbitrary ={ workspace = true, optional = true }
[dev-dependencies]
serial_test = { version = "3.0" }
arbitrary ={ workspace = true }
tracing-subscriber = { workspace = true }
2 changes: 0 additions & 2 deletions tsp/src/cesr/packet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1234,8 +1234,6 @@ mod test {
let message = Base64UrlUnpadded::decode_vec("-EABXAAA7VIDAAAEZGlkOnRlc3Q6Ym9i8VIDAAAFAGRpZDp0ZXN0OmFsaWNl6BAEAABleHRyYSBkYXRh4CAXScvzIiBCgfOu9jHtGwd1qN-KlMB7uhFbE9YOSyTmnp9yziA1LVPdQmST27yjuDRTlxeRo7H7gfuaGFY4iyf2EsfiqvEg0BBNDbKoW0DDczGxj7rNWKH_suyj18HCUxMZ6-mDymZdNhHZIS8zIstC9Kxv5Q-GxmI-1v4SNbeCemuCMBzMPogK").unwrap();
let parts = open_message_into_parts(&message).unwrap();

dbg!(&parts);

assert_eq!(parts.prefix.prefix.len(), 6);
assert_eq!(parts.sender.data.len(), 12);
assert_eq!(parts.receiver.unwrap().data.len(), 14);
Expand Down
2 changes: 0 additions & 2 deletions tsp/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,6 @@ async fn test_large_messages() {
.await
.unwrap();

dbg!(sent_message.len());

// receive a message
let crate::definitions::ReceivedTspMessage::GenericMessage {
message,
Expand Down
4 changes: 4 additions & 0 deletions tsp/src/transport/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ pub enum TransportError {
Http(String, reqwest::Error),
#[error("connection to '{0}' failed: {1}")]
Connection(String, std::io::Error),
#[error("connection to '{0}' failed: {1}")]
QuicConnection(String, quinn::ConnectError),
#[error("invalid address '{0}'")]
InvalidTransportAddress(String),
#[error("invalid transport scheme '{0}'")]
Expand All @@ -22,4 +24,6 @@ pub enum TransportError {
TLS(#[from] rustls::Error),
#[error("internel error")]
Internal,
#[error("could not listen on random UDP port")]
ListenPort,
}
5 changes: 2 additions & 3 deletions tsp/src/transport/http.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use crate::definitions::TSPStream;
use async_stream::stream;
use futures::StreamExt;
use tokio_util::bytes::BytesMut;
use url::Url;

use super::TransportError;
Expand All @@ -28,7 +27,7 @@ pub(crate) async fn send_message(tsp_message: &[u8], url: &Url) -> Result<(), Tr

pub(crate) async fn receive_messages(
address: &Url,
) -> Result<TSPStream<BytesMut, TransportError>, TransportError> {
) -> Result<TSPStream<Vec<u8>, TransportError>, TransportError> {
let mut ws_address = address.clone();

match address.scheme() {
Expand All @@ -49,7 +48,7 @@ pub(crate) async fn receive_messages(
while let Some(Ok(msg)) = receiver.next().await {
match msg {
tokio_tungstenite::tungstenite::Message::Binary(b) => {
yield Ok(BytesMut::from(&b[..]));
yield Ok(b);
}
m => {
yield Err(TransportError::InvalidMessageReceived(
Expand Down
8 changes: 5 additions & 3 deletions tsp/src/transport/mod.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use crate::definitions::TSPStream;
use tokio_util::bytes::BytesMut;
use url::Url;

pub mod error;

mod http;
pub mod tcp;
mod quic;
mod tcp;
mod tls;

pub use error::TransportError;
Expand All @@ -14,6 +14,7 @@ pub async fn send_message(transport: &Url, tsp_message: &[u8]) -> Result<(), Tra
match transport.scheme() {
tcp::SCHEME => tcp::send_message(tsp_message, transport).await,
tls::SCHEME => tls::send_message(tsp_message, transport).await,
quic::SCHEME => quic::send_message(tsp_message, transport).await,
http::SCHEME_HTTP => http::send_message(tsp_message, transport).await,
http::SCHEME_HTTPS => http::send_message(tsp_message, transport).await,
_ => Err(TransportError::InvalidTransportScheme(
Expand All @@ -24,10 +25,11 @@ pub async fn send_message(transport: &Url, tsp_message: &[u8]) -> Result<(), Tra

pub async fn receive_messages(
transport: &Url,
) -> Result<TSPStream<BytesMut, TransportError>, TransportError> {
) -> Result<TSPStream<Vec<u8>, TransportError>, TransportError> {
match transport.scheme() {
tcp::SCHEME => tcp::receive_messages(transport).await,
tls::SCHEME => tls::receive_messages(transport).await,
quic::SCHEME => quic::receive_messages(transport).await,
http::SCHEME_HTTP => http::receive_messages(transport).await,
http::SCHEME_HTTPS => http::receive_messages(transport).await,
_ => Err(TransportError::InvalidTransportScheme(
Expand Down
177 changes: 177 additions & 0 deletions tsp/src/transport/quic.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
use async_stream::stream;
use lazy_static::lazy_static;
use quinn::{
crypto::rustls::{QuicClientConfig, QuicServerConfig},
ClientConfig, Endpoint,
};
use std::sync::Arc;
use tokio::sync::mpsc;
use url::Url;

use super::TransportError;
use crate::definitions::TSPStream;

pub(crate) const SCHEME: &str = "quic";

pub const ALPN_QUIC_HTTP: &[&[u8]] = &[b"hq-29"];

lazy_static! {
static ref QUIC_CONFIG: ClientConfig = {
let mut config = super::tls::TLS_CONFIG.clone().deref().clone();
config.alpn_protocols = ALPN_QUIC_HTTP.iter().map(|&x| x.into()).collect();

quinn::ClientConfig::new(Arc::new(
QuicClientConfig::try_from(config)
.expect("could not convert TLS config to QUIC config"),
))
};
}

/// Send a message over QUIC
/// Connects to the specified transport address and sends the message.
/// Note that a new connection is opened for each message.
pub(crate) async fn send_message(tsp_message: &[u8], url: &Url) -> Result<(), TransportError> {
let addresses = url
.socket_addrs(|| None)
.map_err(|_| TransportError::InvalidTransportAddress(url.to_string()))?;

let Some(address) = addresses.first().cloned() else {
return Err(TransportError::InvalidTransportAddress(url.to_string()));
};

let domain = url
.domain()
.ok_or(TransportError::InvalidTransportAddress(format!(
"could not resolve {url} to a domain"
)))?
.to_owned();

// passing 0 as port number opens a random port
let mut endpoint =
Endpoint::client(([127, 0, 0, 1], 0).into()).map_err(|_| TransportError::ListenPort)?;

endpoint.set_default_client_config(QUIC_CONFIG.clone());

let connection = endpoint
.connect(address, &domain)
.map_err(|e| TransportError::QuicConnection(address.to_string(), e))?
.await
.map_err(|e| TransportError::Connection(address.to_string(), e.into()))?;

let mut send = connection
.open_uni()
.await
.map_err(|e| TransportError::Connection(address.to_string(), e.into()))?;

send.write_all(tsp_message)
.await
.map_err(|e| TransportError::Connection(address.to_string(), e.into()))?;

send.finish()
.map_err(|e| TransportError::Connection(address.to_string(), e.into()))?;

send.stopped()
.await
.map_err(|e| TransportError::Connection(address.to_string(), e.into()))?;

connection.close(0u32.into(), b"done");

Ok(())
}

/// Receive (multiple) messages over QUIC
/// Listens on the specified transport port and yields messages as they arrive
/// This function handles multiple connections and messages and
/// combines them in a single stream. It uses an internal queue of 16 messages.
pub(crate) async fn receive_messages(
address: &Url,
) -> Result<TSPStream<Vec<u8>, TransportError>, TransportError> {
let addresses = address
.socket_addrs(|| None)
.map_err(|_| TransportError::InvalidTransportAddress(address.to_string()))?;

let Some(address) = addresses.first().cloned() else {
return Err(TransportError::InvalidTransportAddress(address.to_string()));
};

let (cert, key) = super::tls::load_certificate()?;

let mut server_crypto =
rustls::ServerConfig::builder_with_provider(super::tls::CRYPTO_PROVIDER.clone())
.with_safe_default_protocol_versions()?
.with_no_client_auth()
.with_single_cert(cert, key)?;

server_crypto.alpn_protocols = ALPN_QUIC_HTTP.iter().map(|&x| x.into()).collect();

let server_config = quinn::ServerConfig::with_crypto(Arc::new(
QuicServerConfig::try_from(server_crypto).map_err(|_| TransportError::Internal)?,
));

let endpoint = Endpoint::server(server_config, address)
.map_err(|e| TransportError::Connection(address.to_string(), e))?;

let (tx, mut rx) = mpsc::channel::<Result<Vec<u8>, TransportError>>(16);

tokio::spawn(async move {
while let Some(incoming_conn) = endpoint.accept().await {
let tx = tx.clone();

tokio::spawn(async move {
let conn = incoming_conn
.await
.map_err(|e| TransportError::Connection(address.to_string(), e.into()))?;

let receive = conn.accept_uni().await;

let mut receive = match receive {
Err(quinn::ConnectionError::ApplicationClosed { .. }) => {
return Ok(());
}
Err(e) => {
return Err(TransportError::Connection(address.to_string(), e.into()));
}
Ok(s) => s,
};

let message = receive.read_to_end(8 * 1024).await.map_err(|_| {
TransportError::InvalidMessageReceived(format!(
"message from {address} is too long",
))
});

tx.send(message)
.await
.map_err(|_| TransportError::Internal)?;

Ok(())
});
}
});

Ok(Box::pin(stream! {
while let Some(item) = rx.recv().await {
yield item;
}
}))
}

#[cfg(test)]
mod tests {
use super::*;
use futures::StreamExt;

#[tokio::test]
async fn test_quic_transport() {
let url = Url::parse("quic://localhost:3737").unwrap();
let message = b"Hello, world!";

let mut incoming_stream = receive_messages(&url).await.unwrap();

send_message(message, &url).await.unwrap();

let received_message = incoming_stream.next().await.unwrap().unwrap();

assert_eq!(message, received_message.as_slice());
}
}
11 changes: 4 additions & 7 deletions tsp/src/transport/tcp.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
use async_stream::stream;
use futures::StreamExt;
use tokio::{io::AsyncWriteExt, net::TcpListener};
use tokio_util::{
bytes::BytesMut,
codec::{BytesCodec, Framed},
};
use tokio_util::codec::{BytesCodec, Framed};
use url::Url;

use super::{TSPStream, TransportError};
Expand Down Expand Up @@ -38,7 +35,7 @@ pub(crate) async fn send_message(tsp_message: &[u8], url: &Url) -> Result<(), Tr
/// Listens on the specified transport port and yields messages as they arrive
pub(crate) async fn receive_messages(
address: &Url,
) -> Result<TSPStream<BytesMut, TransportError>, TransportError> {
) -> Result<TSPStream<Vec<u8>, TransportError>, TransportError> {
let addresses = address
.socket_addrs(|| None)
.map_err(|_| TransportError::InvalidTransportAddress(address.to_string()))?;
Expand All @@ -56,7 +53,7 @@ pub(crate) async fn receive_messages(
let mut messages = Framed::new(stream, BytesCodec::new());

while let Some(m) = messages.next().await {
yield m.map_err(|e| TransportError::Connection(addr.to_string(), e));
yield m.map(|m| m.to_vec()).map_err(|e| TransportError::Connection(addr.to_string(), e));
}
}
}))
Expand All @@ -78,6 +75,6 @@ mod test {
send_message(message, &url).await.unwrap();
let received_message = incoming_stream.next().await.unwrap().unwrap();

assert_eq!(message, received_message.as_ref());
assert_eq!(message, received_message.as_slice());
}
}
Loading

0 comments on commit bc34982

Please sign in to comment.