diff --git a/Cargo.lock b/Cargo.lock index 10ebfe4..35fc6fe 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3479,6 +3479,8 @@ dependencies = [ "ark-serialize", "ark-std", "async-compatibility-layer", + "async-h1", + "async-lock 3.3.0", "async-std", "async-trait", "async-tungstenite 0.25.0", diff --git a/Cargo.toml b/Cargo.toml index 0c44c55..4a5fc68 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,8 @@ required-features = ["testing"] [dependencies] anyhow = "1.0" +async-h1 = "2.3" +async-lock = "3.3" async-std = { version = "1.12", features = ["attributes", "tokio1"] } async-trait = "0.1.79" clap = { version = "4.5", features = ["derive"] } diff --git a/flake.lock b/flake.lock index 0d05c42..a8aa7d8 100644 --- a/flake.lock +++ b/flake.lock @@ -137,11 +137,11 @@ "nixpkgs": "nixpkgs_2" }, "locked": { - "lastModified": 1711764554, - "narHash": "sha256-I2/x/jFd7MAuIi3+kncIF0zJwhkFzxpi5XFdT2RLOF8=", + "lastModified": 1716949111, + "narHash": "sha256-ms3aD3Z2jKd1dk8qd0D/N7C8vFxn6z6LQ1G7cvNTVJ8=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "7cf3d11d06dcd12fb62ca2c039f3c5e25b53c5a7", + "rev": "2e7ccf572ce0f0547d4cf4426de4482936882d0e", "type": "github" }, "original": { diff --git a/src/app.rs b/src/app.rs index 67248f3..3a17e94 100644 --- a/src/app.rs +++ b/src/app.rs @@ -667,7 +667,7 @@ where let message = format!("No API matches /{}", path[1..].join("/")); return Ok(Self::top_level_error(req, StatusCode::NotFound, message)); }; - if module.versions.get(&version).is_none() { + if !module.versions.contains_key(&version) { // This version is not supported, list suported versions. return Ok(html! { "Unsupported version v" (version) ". Supported versions are:" @@ -1526,7 +1526,10 @@ mod test { tracing::info!(?res, "<-"); assert_eq!(res.status(), expected_status); let bytes = res.bytes().await.unwrap(); - S::deserialize(&bytes) + anyhow::Context::context( + S::deserialize(&bytes), + format!("failed to deserialize bytes {bytes:?}"), + ) } #[tracing::instrument(skip(client))] diff --git a/src/lib.rs b/src/lib.rs index b9d562f..2ab09ab 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -282,6 +282,7 @@ pub mod api; pub mod app; pub mod error; pub mod healthcheck; +pub mod listener; pub mod method; pub mod metrics; pub mod request; diff --git a/src/listener.rs b/src/listener.rs new file mode 100644 index 0000000..f3c0103 --- /dev/null +++ b/src/listener.rs @@ -0,0 +1,237 @@ +// Copyright (c) 2022 Espresso Systems (espressosys.com) +// This file is part of the tide-disco library. + +// You should have received a copy of the MIT License +// along with the tide-disco library. If not, see . + +use crate::StatusCode; +use async_lock::Semaphore; +use async_std::{ + net::TcpListener, + sync::Arc, + task::{sleep, spawn}, +}; +use async_trait::async_trait; +use derivative::Derivative; +use futures::stream::StreamExt; +use std::{ + fmt::{self, Display, Formatter}, + io::{self, ErrorKind}, + net::SocketAddr, + time::Duration, +}; +use tide::{ + http, + listener::{ListenInfo, Listener, ToListener}, + Server, +}; + +/// TCP listener which accepts only a limited number of connections at a time. +/// +/// This listener is based on [`tide::listener::TcpListener`] and should match the semantics of that +/// listener in every way, accept that when there are more simultaneous outstanding requests than +/// the configured limit, excess requests will fail immediately with error code 429 (Too Many +/// Requests). +#[derive(Derivative)] +#[derivative(Debug(bound = "State: Send + Sync + 'static"))] +pub struct RateLimitListener { + addr: SocketAddr, + listener: Option, + server: Option>, + info: Option, + permit: Arc, +} + +impl RateLimitListener { + /// Listen at the given address. + pub fn new(addr: SocketAddr, limit: usize) -> Self { + Self { + addr, + listener: None, + server: None, + info: None, + permit: Arc::new(Semaphore::new(limit)), + } + } + + /// Listen at the given port on all interfaces. + pub fn with_port(port: u16, limit: usize) -> Self { + Self::new(([0, 0, 0, 0], port).into(), limit) + } +} + +#[async_trait] +impl Listener for RateLimitListener +where + State: Clone + Send + Sync + 'static, +{ + async fn bind(&mut self, app: Server) -> io::Result<()> { + if self.server.is_some() { + return Err(io::Error::new( + ErrorKind::AlreadyExists, + "`bind` should only be called once", + )); + } + self.server = Some(app); + self.listener = Some(TcpListener::bind(&[self.addr][..]).await?); + + // Format the listen information. + let conn_string = format!("{}", self); + let transport = "tcp".to_owned(); + let tls = false; + self.info = Some(ListenInfo::new(conn_string, transport, tls)); + + Ok(()) + } + + async fn accept(&mut self) -> io::Result<()> { + let server = self.server.take().ok_or_else(|| { + io::Error::other("`Listener::bind` must be called before `Listener::accept`") + })?; + let listener = self.listener.take().ok_or_else(|| { + io::Error::other("`Listener::bind` must be called before `Listener::accept`") + })?; + + let mut incoming = listener.incoming(); + while let Some(stream) = incoming.next().await { + match stream { + Err(err) if is_transient_error(&err) => continue, + Err(err) => { + tracing::warn!(%err, "TCP error"); + sleep(Duration::from_millis(500)).await; + continue; + } + Ok(stream) => { + let app = server.clone(); + let permit = self.permit.clone(); + spawn(async move { + let local_addr = stream.local_addr().ok(); + let peer_addr = stream.peer_addr().ok(); + + let fut = async_h1::accept(stream, |mut req| async { + // Handle the request if we can get a permit. + if let Some(_guard) = permit.try_acquire() { + req.set_local_addr(local_addr); + req.set_peer_addr(peer_addr); + app.respond(req).await + } else { + // Otherwise, we are rate limited. Respond immediately with an + // error. + Ok(http::Response::new(StatusCode::TooManyRequests)) + } + }); + + if let Err(error) = fut.await { + tracing::error!(%error, "HTTP error"); + } + }); + } + }; + } + Ok(()) + } + + fn info(&self) -> Vec { + match &self.info { + Some(info) => vec![info.clone()], + None => vec![], + } + } +} + +impl ToListener for RateLimitListener +where + State: Clone + Send + Sync + 'static, +{ + type Listener = Self; + + fn to_listener(self) -> io::Result { + Ok(self) + } +} + +impl Display for RateLimitListener { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match &self.listener { + Some(listener) => { + let addr = listener.local_addr().expect("Could not get local addr"); + write!(f, "http://{}", addr) + } + None => write!(f, "http://{}", self.addr), + } + } +} + +fn is_transient_error(e: &io::Error) -> bool { + matches!( + e.kind(), + ErrorKind::ConnectionRefused | ErrorKind::ConnectionAborted | ErrorKind::ConnectionReset + ) +} + +#[cfg(test)] +mod test { + use super::*; + use crate::{ + error::ServerError, + testing::{setup_test, Client}, + App, + }; + use futures::future::{try_join_all, FutureExt}; + use portpicker::pick_unused_port; + use toml::toml; + use vbs::version::{StaticVersion, StaticVersionType}; + + type StaticVer01 = StaticVersion<0, 1>; + + #[async_std::test] + async fn test_rate_limiting() { + setup_test(); + + let mut app = App::<_, ServerError>::with_state(()); + let api_toml = toml! { + [route.test] + PATH = ["/test"] + METHOD = "GET" + }; + { + let mut api = app + .module::("mod", api_toml) + .unwrap(); + api.get("test", |_req, _state| { + async move { + // Make a really slow endpoint so we can have many simultaneous requests. + sleep(Duration::from_secs(30)).await; + Ok(()) + } + .boxed() + }) + .unwrap(); + } + + let limit = 10; + let port = pick_unused_port().unwrap(); + spawn(app.serve( + RateLimitListener::with_port(port, limit), + StaticVer01::instance(), + )); + let client = Client::new(format!("http://localhost:{port}").parse().unwrap()).await; + + // Start the maximum number of simultaneous requests. + let reqs = (0..limit) + .map(|_| spawn(client.get("mod/test").send())) + .collect::>(); + + // Wait a bit for those requests to get accepted. + sleep(Duration::from_secs(5)).await; + + // The next request gets rate limited. + let res = client.get("mod/test").send().await.unwrap(); + assert_eq!(StatusCode::TooManyRequests, res.status()); + + // The other requests eventually complete successfully. + for res in try_join_all(reqs).await.unwrap() { + assert_eq!(StatusCode::Ok, res.status()); + } + } +} diff --git a/src/route.rs b/src/route.rs index 06f8a26..0e24cc9 100644 --- a/src/route.rs +++ b/src/route.rs @@ -624,13 +624,13 @@ pub(crate) fn health_check_response( /// /// Given a handler, this function can be used to derive a new, type-erased [HealthCheckHandler] /// that takes only [RequestParams] and returns a generic [tide::Response]. -pub(crate) fn health_check_handler( +pub(crate) fn health_check_handler( handler: impl 'static + Send + Sync + Fn(&State) -> BoxFuture, ) -> HealthCheckHandler where State: 'static + Send + Sync, H: 'static + HealthCheck, - VER: 'static + Send + Sync, + VER: 'static + Send + Sync + StaticVersionType, { Box::new(move |req, state| { let accept = req.accept().unwrap_or_else(|_| {