Skip to content

Commit

Permalink
introduce opt-in callback for limit layer
Browse files Browse the repository at this point in the history
allows you to keep it compatible with inner
service Result

closes #324
  • Loading branch information
GlenDC committed Oct 14, 2024
1 parent 45b8c50 commit ced9c60
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 15 deletions.
37 changes: 37 additions & 0 deletions rama-core/src/layer/limit/into_response.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
use std::fmt;

pub(super) trait ErrorIntoResponse<Error>: Send + Sync + 'static {
type Response: Send + 'static;
type Error: Send + Sync + 'static;

fn error_into_response(&self, error: Error) -> Result<Self::Response, Self::Error>;
}

/// Wrapper around a user-specific transformer callback.
pub struct ErrorIntoResponseFn<F>(pub(super) F);

impl<F: fmt::Debug> fmt::Debug for ErrorIntoResponseFn<F> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("ErrorIntoResponseFn").field(&self.0).finish()
}
}

impl<F: Clone> Clone for ErrorIntoResponseFn<F> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}

impl<Response, ErrorIn, ErrorOut, F> ErrorIntoResponse<ErrorIn> for ErrorIntoResponseFn<F>
where
F: Fn(ErrorIn) -> Result<Response, ErrorOut> + Send + Sync + 'static,
Response: Send + 'static,
ErrorOut: Send + Sync + 'static,
{
type Response = Response;
type Error = ErrorOut;

fn error_into_response(&self, error: ErrorIn) -> Result<Self::Response, ErrorOut> {
(self.0)(error)
}
}
37 changes: 29 additions & 8 deletions rama-core/src/layer/limit/layer.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,32 @@
use std::fmt;

use super::{policy::UnlimitedPolicy, Limit};
use super::{into_response::ErrorIntoResponseFn, policy::UnlimitedPolicy, Limit};
use crate::Layer;

/// Limit requests based on a [`Policy`].
///
/// [`Policy`]: crate::layer::limit::Policy
pub struct LimitLayer<P> {
pub struct LimitLayer<P, F = ()> {
policy: P,
error_into_response: F,
}

impl<P> LimitLayer<P> {
/// Creates a new [`LimitLayer`] from a [`crate::layer::limit::Policy`].
pub const fn new(policy: P) -> Self {
LimitLayer { policy }
LimitLayer {
policy,
error_into_response: (),
}
}

/// Attach a function to this [`LimitLayer`] to allow you to turn the Policy error
/// into a Result fully compatible with the inner `Service` Result.
pub fn with_error_into_response_fn<F>(self, f: F) -> LimitLayer<P, ErrorIntoResponseFn<F>> {
LimitLayer {
policy: self.policy,
error_into_response: ErrorIntoResponseFn(f),
}
}
}

Expand All @@ -26,33 +39,41 @@ impl LimitLayer<UnlimitedPolicy> {
}
}

impl<P: fmt::Debug> std::fmt::Debug for LimitLayer<P> {
impl<P: fmt::Debug, F: fmt::Debug> std::fmt::Debug for LimitLayer<P, F> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("LimitLayer")
.field("policy", &self.policy)
.field("self.error_into_response", &self.error_into_response)
.finish()
}
}

impl<P> Clone for LimitLayer<P>
impl<P, F> Clone for LimitLayer<P, F>
where
P: Clone,
F: Clone,
{
fn clone(&self) -> Self {
Self {
policy: self.policy.clone(),
error_into_response: self.error_into_response.clone(),
}
}
}

impl<T, P> Layer<T> for LimitLayer<P>
impl<T, P, F> Layer<T> for LimitLayer<P, F>
where
P: Clone,
F: Clone,
{
type Service = Limit<T, P>;
type Service = Limit<T, P, F>;

fn layer(&self, service: T) -> Self::Service {
let policy = self.policy.clone();
Limit::new(service, policy)
Limit {
inner: service,
policy,
error_into_response: self.error_into_response.clone(),
}
}
}
96 changes: 89 additions & 7 deletions rama-core/src/layer/limit/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::fmt;

use crate::error::BoxError;
use crate::{Context, Service};
use into_response::{ErrorIntoResponse, ErrorIntoResponseFn};
use rama_utils::macros::define_inner_service_accessors;

pub mod policy;
Expand All @@ -16,59 +17,80 @@ mod layer;
#[doc(inline)]
pub use layer::LimitLayer;

mod into_response;

/// Limit requests based on a [`Policy`].
///
/// [`Policy`]: crate::layer::limit::Policy
pub struct Limit<S, P> {
pub struct Limit<S, P, F = ()> {
inner: S,
policy: P,
error_into_response: F,
}

impl<S, P> Limit<S, P> {
impl<S, P> Limit<S, P, ()> {
/// Creates a new [`Limit`] from a limit [`Policy`],
/// wrapping the given [`Service`].
pub const fn new(inner: S, policy: P) -> Self {
Limit { inner, policy }
Limit {
inner,
policy,
error_into_response: (),
}
}

/// Attach a function to this [`Limit`] to allow you to turn the Policy error
/// into a Result fully compatible with the inner `Service` Result.
pub fn with_error_into_response_fn<F>(self, f: F) -> Limit<S, P, ErrorIntoResponseFn<F>> {
Limit {
inner: self.inner,
policy: self.policy,
error_into_response: ErrorIntoResponseFn(f),
}
}

define_inner_service_accessors!();
}

impl<T> Limit<T, UnlimitedPolicy> {
impl<T> Limit<T, UnlimitedPolicy, ()> {
/// Creates a new [`Limit`] with an unlimited policy.
///
/// Meaning that all requests are allowed to proceed.
pub const fn unlimited(inner: T) -> Self {
Limit {
inner,
policy: UnlimitedPolicy,
error_into_response: (),
}
}
}

impl<T: fmt::Debug, P: fmt::Debug> fmt::Debug for Limit<T, P> {
impl<T: fmt::Debug, P: fmt::Debug, F: fmt::Debug> fmt::Debug for Limit<T, P, F> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Limit")
.field("inner", &self.inner)
.field("policy", &self.policy)
.field("error_into_response", &self.error_into_response)
.finish()
}
}

impl<T, P> Clone for Limit<T, P>
impl<T, P, F> Clone for Limit<T, P, F>
where
T: Clone,
P: Clone,
F: Clone,
{
fn clone(&self) -> Self {
Limit {
inner: self.inner.clone(),
policy: self.policy.clone(),
error_into_response: self.error_into_response.clone(),
}
}
}

impl<T, P, State, Request> Service<State, Request> for Limit<T, P>
impl<T, P, State, Request> Service<State, Request> for Limit<T, P, ()>
where
T: Service<State, Request, Error: Into<BoxError>>,
P: policy::Policy<State, Request, Error: Into<BoxError>>,
Expand Down Expand Up @@ -100,6 +122,47 @@ where
}
}

impl<T, P, F, State, Request, FnResponse, FnError> Service<State, Request>
for Limit<T, P, ErrorIntoResponseFn<F>>
where
T: Service<State, Request>,
P: policy::Policy<State, Request>,
F: Fn(P::Error) -> Result<FnResponse, FnError> + Send + Sync + 'static,
FnResponse: Into<T::Response> + Send + 'static,
FnError: Into<T::Error> + Send + Sync + 'static,
Request: Send + Sync + 'static,
State: Clone + Send + Sync + 'static,
{
type Response = T::Response;
type Error = T::Error;

async fn serve(
&self,
mut ctx: Context<State>,
mut request: Request,
) -> Result<Self::Response, Self::Error> {
loop {
let result = self.policy.check(ctx, request).await;
ctx = result.ctx;
request = result.request;

match result.output {
policy::PolicyOutput::Ready(guard) => {
let _ = guard;
return self.inner.serve(ctx, request).await;
}
policy::PolicyOutput::Abort(err) => {
return match self.error_into_response.error_into_response(err) {
Ok(ok) => Ok(ok.into()),
Err(err) => Err(err.into()),
}
}
policy::PolicyOutput::Retry => (),
}
}
}
}

#[cfg(test)]
mod tests {
use super::policy::ConcurrentPolicy;
Expand Down Expand Up @@ -139,6 +202,25 @@ mod tests {
}
}

#[tokio::test]
async fn test_with_error_into_response_fn() {
async fn handle_request<State, Request>(
_ctx: Context<State>,
_req: Request,
) -> Result<&'static str, Infallible> {
Ok("good")
}

let layer: LimitLayer<ConcurrentPolicy<_, _>, _> =
LimitLayer::new(ConcurrentPolicy::max(0))
.with_error_into_response_fn(|_| Ok::<_, Infallible>("bad"));

let service = layer.layer(service_fn(handle_request));

let resp = service.serve(Context::default(), "Hello").await.unwrap();
assert_eq!("bad", resp);
}

#[tokio::test]
async fn test_zero_limit() {
async fn handle_request<State, Request>(
Expand Down

0 comments on commit ced9c60

Please sign in to comment.