diff --git a/rama-core/src/layer/limit/into_response.rs b/rama-core/src/layer/limit/into_response.rs new file mode 100644 index 00000000..7568275a --- /dev/null +++ b/rama-core/src/layer/limit/into_response.rs @@ -0,0 +1,37 @@ +use std::fmt; + +pub(super) trait ErrorIntoResponse: Send + Sync + 'static { + type Response: Send + 'static; + type Error: Send + Sync + 'static; + + fn error_into_response(&self, error: Error) -> Result; +} + +/// Wrapper around a user-specific transformer callback. +pub struct ErrorIntoResponseFn(pub(super) F); + +impl fmt::Debug for ErrorIntoResponseFn { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("ErrorIntoResponseFn").field(&self.0).finish() + } +} + +impl Clone for ErrorIntoResponseFn { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} + +impl ErrorIntoResponse for ErrorIntoResponseFn +where + F: Fn(ErrorIn) -> Result + Send + Sync + 'static, + Response: Send + 'static, + ErrorOut: Send + Sync + 'static, +{ + type Response = Response; + type Error = ErrorOut; + + fn error_into_response(&self, error: ErrorIn) -> Result { + (self.0)(error) + } +} diff --git a/rama-core/src/layer/limit/layer.rs b/rama-core/src/layer/limit/layer.rs index dffe651d..56fd91b7 100644 --- a/rama-core/src/layer/limit/layer.rs +++ b/rama-core/src/layer/limit/layer.rs @@ -1,19 +1,32 @@ use std::fmt; -use super::{policy::UnlimitedPolicy, Limit}; +use super::{into_response::ErrorIntoResponseFn, policy::UnlimitedPolicy, Limit}; use crate::Layer; /// Limit requests based on a [`Policy`]. /// /// [`Policy`]: crate::layer::limit::Policy -pub struct LimitLayer

{ +pub struct LimitLayer { policy: P, + error_into_response: F, } impl

LimitLayer

{ /// Creates a new [`LimitLayer`] from a [`crate::layer::limit::Policy`]. pub const fn new(policy: P) -> Self { - LimitLayer { policy } + LimitLayer { + policy, + error_into_response: (), + } + } + + /// Attach a function to this [`LimitLayer`] to allow you to turn the Policy error + /// into a Result fully compatible with the inner `Service` Result. + pub fn with_error_into_response_fn(self, f: F) -> LimitLayer> { + LimitLayer { + policy: self.policy, + error_into_response: ErrorIntoResponseFn(f), + } } } @@ -26,33 +39,41 @@ impl LimitLayer { } } -impl std::fmt::Debug for LimitLayer

{ +impl std::fmt::Debug for LimitLayer { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("LimitLayer") .field("policy", &self.policy) + .field("self.error_into_response", &self.error_into_response) .finish() } } -impl

Clone for LimitLayer

+impl Clone for LimitLayer where P: Clone, + F: Clone, { fn clone(&self) -> Self { Self { policy: self.policy.clone(), + error_into_response: self.error_into_response.clone(), } } } -impl Layer for LimitLayer

+impl Layer for LimitLayer where P: Clone, + F: Clone, { - type Service = Limit; + type Service = Limit; fn layer(&self, service: T) -> Self::Service { let policy = self.policy.clone(); - Limit::new(service, policy) + Limit { + inner: service, + policy, + error_into_response: self.error_into_response.clone(), + } } } diff --git a/rama-core/src/layer/limit/mod.rs b/rama-core/src/layer/limit/mod.rs index 6c1a21d2..29573965 100644 --- a/rama-core/src/layer/limit/mod.rs +++ b/rama-core/src/layer/limit/mod.rs @@ -6,6 +6,7 @@ use std::fmt; use crate::error::BoxError; use crate::{Context, Service}; +use into_response::{ErrorIntoResponse, ErrorIntoResponseFn}; use rama_utils::macros::define_inner_service_accessors; pub mod policy; @@ -16,25 +17,42 @@ mod layer; #[doc(inline)] pub use layer::LimitLayer; +mod into_response; + /// Limit requests based on a [`Policy`]. /// /// [`Policy`]: crate::layer::limit::Policy -pub struct Limit { +pub struct Limit { inner: S, policy: P, + error_into_response: F, } -impl Limit { +impl Limit { /// Creates a new [`Limit`] from a limit [`Policy`], /// wrapping the given [`Service`]. pub const fn new(inner: S, policy: P) -> Self { - Limit { inner, policy } + Limit { + inner, + policy, + error_into_response: (), + } + } + + /// Attach a function to this [`Limit`] to allow you to turn the Policy error + /// into a Result fully compatible with the inner `Service` Result. + pub fn with_error_into_response_fn(self, f: F) -> Limit> { + Limit { + inner: self.inner, + policy: self.policy, + error_into_response: ErrorIntoResponseFn(f), + } } define_inner_service_accessors!(); } -impl Limit { +impl Limit { /// Creates a new [`Limit`] with an unlimited policy. /// /// Meaning that all requests are allowed to proceed. @@ -42,33 +60,37 @@ impl Limit { Limit { inner, policy: UnlimitedPolicy, + error_into_response: (), } } } -impl fmt::Debug for Limit { +impl fmt::Debug for Limit { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Limit") .field("inner", &self.inner) .field("policy", &self.policy) + .field("error_into_response", &self.error_into_response) .finish() } } -impl Clone for Limit +impl Clone for Limit where T: Clone, P: Clone, + F: Clone, { fn clone(&self) -> Self { Limit { inner: self.inner.clone(), policy: self.policy.clone(), + error_into_response: self.error_into_response.clone(), } } } -impl Service for Limit +impl Service for Limit where T: Service>, P: policy::Policy>, @@ -100,6 +122,47 @@ where } } +impl Service + for Limit> +where + T: Service, + P: policy::Policy, + F: Fn(P::Error) -> Result + Send + Sync + 'static, + FnResponse: Into + Send + 'static, + FnError: Into + Send + Sync + 'static, + Request: Send + Sync + 'static, + State: Clone + Send + Sync + 'static, +{ + type Response = T::Response; + type Error = T::Error; + + async fn serve( + &self, + mut ctx: Context, + mut request: Request, + ) -> Result { + loop { + let result = self.policy.check(ctx, request).await; + ctx = result.ctx; + request = result.request; + + match result.output { + policy::PolicyOutput::Ready(guard) => { + let _ = guard; + return self.inner.serve(ctx, request).await; + } + policy::PolicyOutput::Abort(err) => { + return match self.error_into_response.error_into_response(err) { + Ok(ok) => Ok(ok.into()), + Err(err) => Err(err.into()), + } + } + policy::PolicyOutput::Retry => (), + } + } + } +} + #[cfg(test)] mod tests { use super::policy::ConcurrentPolicy; @@ -139,6 +202,25 @@ mod tests { } } + #[tokio::test] + async fn test_with_error_into_response_fn() { + async fn handle_request( + _ctx: Context, + _req: Request, + ) -> Result<&'static str, Infallible> { + Ok("good") + } + + let layer: LimitLayer, _> = + LimitLayer::new(ConcurrentPolicy::max(0)) + .with_error_into_response_fn(|_| Ok::<_, Infallible>("bad")); + + let service = layer.layer(service_fn(handle_request)); + + let resp = service.serve(Context::default(), "Hello").await.unwrap(); + assert_eq!("bad", resp); + } + #[tokio::test] async fn test_zero_limit() { async fn handle_request(