diff --git a/Cargo.lock b/Cargo.lock
index 10ebfe4..35fc6fe 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3479,6 +3479,8 @@ dependencies = [
"ark-serialize",
"ark-std",
"async-compatibility-layer",
+ "async-h1",
+ "async-lock 3.3.0",
"async-std",
"async-trait",
"async-tungstenite 0.25.0",
diff --git a/Cargo.toml b/Cargo.toml
index 0c44c55..4a5fc68 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -22,6 +22,8 @@ required-features = ["testing"]
[dependencies]
anyhow = "1.0"
+async-h1 = "2.3"
+async-lock = "3.3"
async-std = { version = "1.12", features = ["attributes", "tokio1"] }
async-trait = "0.1.79"
clap = { version = "4.5", features = ["derive"] }
diff --git a/flake.lock b/flake.lock
index 0d05c42..a8aa7d8 100644
--- a/flake.lock
+++ b/flake.lock
@@ -137,11 +137,11 @@
"nixpkgs": "nixpkgs_2"
},
"locked": {
- "lastModified": 1711764554,
- "narHash": "sha256-I2/x/jFd7MAuIi3+kncIF0zJwhkFzxpi5XFdT2RLOF8=",
+ "lastModified": 1716949111,
+ "narHash": "sha256-ms3aD3Z2jKd1dk8qd0D/N7C8vFxn6z6LQ1G7cvNTVJ8=",
"owner": "oxalica",
"repo": "rust-overlay",
- "rev": "7cf3d11d06dcd12fb62ca2c039f3c5e25b53c5a7",
+ "rev": "2e7ccf572ce0f0547d4cf4426de4482936882d0e",
"type": "github"
},
"original": {
diff --git a/src/app.rs b/src/app.rs
index 67248f3..3a17e94 100644
--- a/src/app.rs
+++ b/src/app.rs
@@ -667,7 +667,7 @@ where
let message = format!("No API matches /{}", path[1..].join("/"));
return Ok(Self::top_level_error(req, StatusCode::NotFound, message));
};
- if module.versions.get(&version).is_none() {
+ if !module.versions.contains_key(&version) {
// This version is not supported, list suported versions.
return Ok(html! {
"Unsupported version v" (version) ". Supported versions are:"
@@ -1526,7 +1526,10 @@ mod test {
tracing::info!(?res, "<-");
assert_eq!(res.status(), expected_status);
let bytes = res.bytes().await.unwrap();
- S::deserialize(&bytes)
+ anyhow::Context::context(
+ S::deserialize(&bytes),
+ format!("failed to deserialize bytes {bytes:?}"),
+ )
}
#[tracing::instrument(skip(client))]
diff --git a/src/lib.rs b/src/lib.rs
index b9d562f..2ab09ab 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -282,6 +282,7 @@ pub mod api;
pub mod app;
pub mod error;
pub mod healthcheck;
+pub mod listener;
pub mod method;
pub mod metrics;
pub mod request;
diff --git a/src/listener.rs b/src/listener.rs
new file mode 100644
index 0000000..f3c0103
--- /dev/null
+++ b/src/listener.rs
@@ -0,0 +1,237 @@
+// Copyright (c) 2022 Espresso Systems (espressosys.com)
+// This file is part of the tide-disco library.
+
+// You should have received a copy of the MIT License
+// along with the tide-disco library. If not, see .
+
+use crate::StatusCode;
+use async_lock::Semaphore;
+use async_std::{
+ net::TcpListener,
+ sync::Arc,
+ task::{sleep, spawn},
+};
+use async_trait::async_trait;
+use derivative::Derivative;
+use futures::stream::StreamExt;
+use std::{
+ fmt::{self, Display, Formatter},
+ io::{self, ErrorKind},
+ net::SocketAddr,
+ time::Duration,
+};
+use tide::{
+ http,
+ listener::{ListenInfo, Listener, ToListener},
+ Server,
+};
+
+/// TCP listener which accepts only a limited number of connections at a time.
+///
+/// This listener is based on [`tide::listener::TcpListener`] and should match the semantics of that
+/// listener in every way, accept that when there are more simultaneous outstanding requests than
+/// the configured limit, excess requests will fail immediately with error code 429 (Too Many
+/// Requests).
+#[derive(Derivative)]
+#[derivative(Debug(bound = "State: Send + Sync + 'static"))]
+pub struct RateLimitListener {
+ addr: SocketAddr,
+ listener: Option,
+ server: Option>,
+ info: Option,
+ permit: Arc,
+}
+
+impl RateLimitListener {
+ /// Listen at the given address.
+ pub fn new(addr: SocketAddr, limit: usize) -> Self {
+ Self {
+ addr,
+ listener: None,
+ server: None,
+ info: None,
+ permit: Arc::new(Semaphore::new(limit)),
+ }
+ }
+
+ /// Listen at the given port on all interfaces.
+ pub fn with_port(port: u16, limit: usize) -> Self {
+ Self::new(([0, 0, 0, 0], port).into(), limit)
+ }
+}
+
+#[async_trait]
+impl Listener for RateLimitListener
+where
+ State: Clone + Send + Sync + 'static,
+{
+ async fn bind(&mut self, app: Server) -> io::Result<()> {
+ if self.server.is_some() {
+ return Err(io::Error::new(
+ ErrorKind::AlreadyExists,
+ "`bind` should only be called once",
+ ));
+ }
+ self.server = Some(app);
+ self.listener = Some(TcpListener::bind(&[self.addr][..]).await?);
+
+ // Format the listen information.
+ let conn_string = format!("{}", self);
+ let transport = "tcp".to_owned();
+ let tls = false;
+ self.info = Some(ListenInfo::new(conn_string, transport, tls));
+
+ Ok(())
+ }
+
+ async fn accept(&mut self) -> io::Result<()> {
+ let server = self.server.take().ok_or_else(|| {
+ io::Error::other("`Listener::bind` must be called before `Listener::accept`")
+ })?;
+ let listener = self.listener.take().ok_or_else(|| {
+ io::Error::other("`Listener::bind` must be called before `Listener::accept`")
+ })?;
+
+ let mut incoming = listener.incoming();
+ while let Some(stream) = incoming.next().await {
+ match stream {
+ Err(err) if is_transient_error(&err) => continue,
+ Err(err) => {
+ tracing::warn!(%err, "TCP error");
+ sleep(Duration::from_millis(500)).await;
+ continue;
+ }
+ Ok(stream) => {
+ let app = server.clone();
+ let permit = self.permit.clone();
+ spawn(async move {
+ let local_addr = stream.local_addr().ok();
+ let peer_addr = stream.peer_addr().ok();
+
+ let fut = async_h1::accept(stream, |mut req| async {
+ // Handle the request if we can get a permit.
+ if let Some(_guard) = permit.try_acquire() {
+ req.set_local_addr(local_addr);
+ req.set_peer_addr(peer_addr);
+ app.respond(req).await
+ } else {
+ // Otherwise, we are rate limited. Respond immediately with an
+ // error.
+ Ok(http::Response::new(StatusCode::TooManyRequests))
+ }
+ });
+
+ if let Err(error) = fut.await {
+ tracing::error!(%error, "HTTP error");
+ }
+ });
+ }
+ };
+ }
+ Ok(())
+ }
+
+ fn info(&self) -> Vec {
+ match &self.info {
+ Some(info) => vec![info.clone()],
+ None => vec![],
+ }
+ }
+}
+
+impl ToListener for RateLimitListener
+where
+ State: Clone + Send + Sync + 'static,
+{
+ type Listener = Self;
+
+ fn to_listener(self) -> io::Result {
+ Ok(self)
+ }
+}
+
+impl Display for RateLimitListener {
+ fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+ match &self.listener {
+ Some(listener) => {
+ let addr = listener.local_addr().expect("Could not get local addr");
+ write!(f, "http://{}", addr)
+ }
+ None => write!(f, "http://{}", self.addr),
+ }
+ }
+}
+
+fn is_transient_error(e: &io::Error) -> bool {
+ matches!(
+ e.kind(),
+ ErrorKind::ConnectionRefused | ErrorKind::ConnectionAborted | ErrorKind::ConnectionReset
+ )
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+ use crate::{
+ error::ServerError,
+ testing::{setup_test, Client},
+ App,
+ };
+ use futures::future::{try_join_all, FutureExt};
+ use portpicker::pick_unused_port;
+ use toml::toml;
+ use vbs::version::{StaticVersion, StaticVersionType};
+
+ type StaticVer01 = StaticVersion<0, 1>;
+
+ #[async_std::test]
+ async fn test_rate_limiting() {
+ setup_test();
+
+ let mut app = App::<_, ServerError>::with_state(());
+ let api_toml = toml! {
+ [route.test]
+ PATH = ["/test"]
+ METHOD = "GET"
+ };
+ {
+ let mut api = app
+ .module::("mod", api_toml)
+ .unwrap();
+ api.get("test", |_req, _state| {
+ async move {
+ // Make a really slow endpoint so we can have many simultaneous requests.
+ sleep(Duration::from_secs(30)).await;
+ Ok(())
+ }
+ .boxed()
+ })
+ .unwrap();
+ }
+
+ let limit = 10;
+ let port = pick_unused_port().unwrap();
+ spawn(app.serve(
+ RateLimitListener::with_port(port, limit),
+ StaticVer01::instance(),
+ ));
+ let client = Client::new(format!("http://localhost:{port}").parse().unwrap()).await;
+
+ // Start the maximum number of simultaneous requests.
+ let reqs = (0..limit)
+ .map(|_| spawn(client.get("mod/test").send()))
+ .collect::>();
+
+ // Wait a bit for those requests to get accepted.
+ sleep(Duration::from_secs(5)).await;
+
+ // The next request gets rate limited.
+ let res = client.get("mod/test").send().await.unwrap();
+ assert_eq!(StatusCode::TooManyRequests, res.status());
+
+ // The other requests eventually complete successfully.
+ for res in try_join_all(reqs).await.unwrap() {
+ assert_eq!(StatusCode::Ok, res.status());
+ }
+ }
+}
diff --git a/src/route.rs b/src/route.rs
index 06f8a26..0e24cc9 100644
--- a/src/route.rs
+++ b/src/route.rs
@@ -624,13 +624,13 @@ pub(crate) fn health_check_response(
///
/// Given a handler, this function can be used to derive a new, type-erased [HealthCheckHandler]
/// that takes only [RequestParams] and returns a generic [tide::Response].
-pub(crate) fn health_check_handler(
+pub(crate) fn health_check_handler(
handler: impl 'static + Send + Sync + Fn(&State) -> BoxFuture,
) -> HealthCheckHandler
where
State: 'static + Send + Sync,
H: 'static + HealthCheck,
- VER: 'static + Send + Sync,
+ VER: 'static + Send + Sync + StaticVersionType,
{
Box::new(move |req, state| {
let accept = req.accept().unwrap_or_else(|_| {