-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Marlon Baeten <[email protected]>
- Loading branch information
1 parent
332aaaa
commit bc34982
Showing
12 changed files
with
318 additions
and
76 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
use async_stream::stream; | ||
use lazy_static::lazy_static; | ||
use quinn::{ | ||
crypto::rustls::{QuicClientConfig, QuicServerConfig}, | ||
ClientConfig, Endpoint, | ||
}; | ||
use std::sync::Arc; | ||
use tokio::sync::mpsc; | ||
use url::Url; | ||
|
||
use super::TransportError; | ||
use crate::definitions::TSPStream; | ||
|
||
pub(crate) const SCHEME: &str = "quic"; | ||
|
||
pub const ALPN_QUIC_HTTP: &[&[u8]] = &[b"hq-29"]; | ||
|
||
lazy_static! { | ||
static ref QUIC_CONFIG: ClientConfig = { | ||
let mut config = super::tls::TLS_CONFIG.clone().deref().clone(); | ||
config.alpn_protocols = ALPN_QUIC_HTTP.iter().map(|&x| x.into()).collect(); | ||
|
||
quinn::ClientConfig::new(Arc::new( | ||
QuicClientConfig::try_from(config) | ||
.expect("could not convert TLS config to QUIC config"), | ||
)) | ||
}; | ||
} | ||
|
||
/// Send a message over QUIC | ||
/// Connects to the specified transport address and sends the message. | ||
/// Note that a new connection is opened for each message. | ||
pub(crate) async fn send_message(tsp_message: &[u8], url: &Url) -> Result<(), TransportError> { | ||
let addresses = url | ||
.socket_addrs(|| None) | ||
.map_err(|_| TransportError::InvalidTransportAddress(url.to_string()))?; | ||
|
||
let Some(address) = addresses.first().cloned() else { | ||
return Err(TransportError::InvalidTransportAddress(url.to_string())); | ||
}; | ||
|
||
let domain = url | ||
.domain() | ||
.ok_or(TransportError::InvalidTransportAddress(format!( | ||
"could not resolve {url} to a domain" | ||
)))? | ||
.to_owned(); | ||
|
||
// passing 0 as port number opens a random port | ||
let mut endpoint = | ||
Endpoint::client(([127, 0, 0, 1], 0).into()).map_err(|_| TransportError::ListenPort)?; | ||
|
||
endpoint.set_default_client_config(QUIC_CONFIG.clone()); | ||
|
||
let connection = endpoint | ||
.connect(address, &domain) | ||
.map_err(|e| TransportError::QuicConnection(address.to_string(), e))? | ||
.await | ||
.map_err(|e| TransportError::Connection(address.to_string(), e.into()))?; | ||
|
||
let mut send = connection | ||
.open_uni() | ||
.await | ||
.map_err(|e| TransportError::Connection(address.to_string(), e.into()))?; | ||
|
||
send.write_all(tsp_message) | ||
.await | ||
.map_err(|e| TransportError::Connection(address.to_string(), e.into()))?; | ||
|
||
send.finish() | ||
.map_err(|e| TransportError::Connection(address.to_string(), e.into()))?; | ||
|
||
send.stopped() | ||
.await | ||
.map_err(|e| TransportError::Connection(address.to_string(), e.into()))?; | ||
|
||
connection.close(0u32.into(), b"done"); | ||
|
||
Ok(()) | ||
} | ||
|
||
/// Receive (multiple) messages over QUIC | ||
/// Listens on the specified transport port and yields messages as they arrive | ||
/// This function handles multiple connections and messages and | ||
/// combines them in a single stream. It uses an internal queue of 16 messages. | ||
pub(crate) async fn receive_messages( | ||
address: &Url, | ||
) -> Result<TSPStream<Vec<u8>, TransportError>, TransportError> { | ||
let addresses = address | ||
.socket_addrs(|| None) | ||
.map_err(|_| TransportError::InvalidTransportAddress(address.to_string()))?; | ||
|
||
let Some(address) = addresses.first().cloned() else { | ||
return Err(TransportError::InvalidTransportAddress(address.to_string())); | ||
}; | ||
|
||
let (cert, key) = super::tls::load_certificate()?; | ||
|
||
let mut server_crypto = | ||
rustls::ServerConfig::builder_with_provider(super::tls::CRYPTO_PROVIDER.clone()) | ||
.with_safe_default_protocol_versions()? | ||
.with_no_client_auth() | ||
.with_single_cert(cert, key)?; | ||
|
||
server_crypto.alpn_protocols = ALPN_QUIC_HTTP.iter().map(|&x| x.into()).collect(); | ||
|
||
let server_config = quinn::ServerConfig::with_crypto(Arc::new( | ||
QuicServerConfig::try_from(server_crypto).map_err(|_| TransportError::Internal)?, | ||
)); | ||
|
||
let endpoint = Endpoint::server(server_config, address) | ||
.map_err(|e| TransportError::Connection(address.to_string(), e))?; | ||
|
||
let (tx, mut rx) = mpsc::channel::<Result<Vec<u8>, TransportError>>(16); | ||
|
||
tokio::spawn(async move { | ||
while let Some(incoming_conn) = endpoint.accept().await { | ||
let tx = tx.clone(); | ||
|
||
tokio::spawn(async move { | ||
let conn = incoming_conn | ||
.await | ||
.map_err(|e| TransportError::Connection(address.to_string(), e.into()))?; | ||
|
||
let receive = conn.accept_uni().await; | ||
|
||
let mut receive = match receive { | ||
Err(quinn::ConnectionError::ApplicationClosed { .. }) => { | ||
return Ok(()); | ||
} | ||
Err(e) => { | ||
return Err(TransportError::Connection(address.to_string(), e.into())); | ||
} | ||
Ok(s) => s, | ||
}; | ||
|
||
let message = receive.read_to_end(8 * 1024).await.map_err(|_| { | ||
TransportError::InvalidMessageReceived(format!( | ||
"message from {address} is too long", | ||
)) | ||
}); | ||
|
||
tx.send(message) | ||
.await | ||
.map_err(|_| TransportError::Internal)?; | ||
|
||
Ok(()) | ||
}); | ||
} | ||
}); | ||
|
||
Ok(Box::pin(stream! { | ||
while let Some(item) = rx.recv().await { | ||
yield item; | ||
} | ||
})) | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::*; | ||
use futures::StreamExt; | ||
|
||
#[tokio::test] | ||
async fn test_quic_transport() { | ||
let url = Url::parse("quic://localhost:3737").unwrap(); | ||
let message = b"Hello, world!"; | ||
|
||
let mut incoming_stream = receive_messages(&url).await.unwrap(); | ||
|
||
send_message(message, &url).await.unwrap(); | ||
|
||
let received_message = incoming_stream.next().await.unwrap().unwrap(); | ||
|
||
assert_eq!(message, received_message.as_slice()); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.