diff --git a/crates/kernel/src/backend_service.rs b/crates/kernel/src/backend_service.rs new file mode 100644 index 00000000..04f2bd5c --- /dev/null +++ b/crates/kernel/src/backend_service.rs @@ -0,0 +1,143 @@ +use std::convert::Infallible; +use std::path::Path; +use std::sync::Arc; + +use futures_util::future::BoxFuture; +use futures_util::Future; +use hyper::{header::UPGRADE, Request, Response, StatusCode}; +use tracing::instrument; + +use crate::backend_service::http_client_service::get_client; +use crate::helper_layers::map_future::MapFuture; +use crate::utils::x_forwarded_for; +use crate::BoxError; +use crate::SgBody; +use crate::SgResponse; +use crate::SgResponseExt; + +pub mod echo; +pub mod http_client_service; +pub mod static_file_service; +pub mod ws_client_service; +pub(crate) const FILE_SCHEMA: &str = "file"; +pub trait CloneHyperService: hyper::service::Service { + fn clone_box(&self) -> Box + Send + Sync>; +} + +impl CloneHyperService for T +where + T: hyper::service::Service + Send + Sync + Clone + 'static, +{ + fn clone_box(&self) -> Box + Send + Sync> { + Box::new(self.clone()) + } +} +pub struct ArcHyperService { + pub boxed: Arc< + dyn CloneHyperService, Response = Response, Error = Infallible, Future = BoxFuture<'static, Result, Infallible>>> + Send + Sync, + >, +} + +impl std::fmt::Debug for ArcHyperService { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ArcHyperService").finish() + } +} + +impl Clone for ArcHyperService { + fn clone(&self) -> Self { + Self { boxed: self.boxed.clone() } + } +} + +impl ArcHyperService { + pub fn new(service: T) -> Self + where + T: Clone + CloneHyperService, Response = Response, Error = Infallible> + Send + Sync + 'static, + T::Future: Future, Infallible>> + 'static + Send, + { + let map_fut = MapFuture::new(service, |fut| Box::pin(fut) as _); + Self { boxed: Arc::new(map_fut) } + } +} + +impl hyper::service::Service> for ArcHyperService { + type Response = Response; + type Error = Infallible; + type Future = BoxFuture<'static, Result>; + + fn call(&self, req: Request) -> Self::Future { + Box::pin(self.boxed.call(req)) + } +} + +/// Http backend service +/// +/// This function could be a bottom layer of a http router, it will handle http and websocket request. +/// +/// This can handle both websocket connection and http request. +/// +/// # Errors +/// 1. Fail to collect body chunks +/// 2. Fail to upgrade +pub async fn http_backend_service_inner(mut req: Request) -> Result { + tracing::trace!(elapsed = ?req.extensions().get::().map(crate::extension::EnterTime::elapsed), "start a backend request"); + x_forwarded_for(&mut req)?; + if req.uri().scheme_str() == Some(FILE_SCHEMA) { + return Ok(static_file_service::static_file_service(req, Path::new("./")).await); + } + let mut client = get_client(); + let mut response = if req.headers().get(UPGRADE).is_some_and(|upgrade| upgrade.as_bytes().eq_ignore_ascii_case(b"websocket")) { + // dump request + let (part, body) = req.into_parts(); + let body = body.dump().await?; + let req = Request::from_parts(part, body); + + // forward request + let resp = client.request(req.clone()).await; + + // dump response + let (part, body) = resp.into_parts(); + let body = body.dump().await?; + let resp = Response::from_parts(part, body); + + let req_for_upgrade = req.clone(); + let resp_for_upgrade = resp.clone(); + + // create forward task + tokio::task::spawn(async move { + // update both side + let (s, c) = futures_util::join!(hyper::upgrade::on(req_for_upgrade), hyper::upgrade::on(resp_for_upgrade)); + let upgrade_as_server = s?; + let upgrade_as_client = c?; + // start a websocket forward + ws_client_service::tcp_transfer(upgrade_as_server, upgrade_as_client).await?; + >::Ok(()) + }); + tracing::trace!(elapsed = ?resp.extensions().get::().map(crate::extension::EnterTime::elapsed), "finish backend websocket forward"); + // return response to client + resp + } else { + let resp = client.request(req).await; + tracing::trace!(elapsed = ?resp.extensions().get::().map(crate::extension::EnterTime::elapsed), "finish backend request"); + resp + }; + response.extensions_mut().insert(unsafe { crate::extension::FromBackend::new() }); + Ok(response) +} + +#[instrument] +pub async fn http_backend_service(req: Request) -> Result, Infallible> { + match http_backend_service_inner(req).await { + Ok(resp) => Ok(resp), + Err(err) => Ok(Response::with_code_message(StatusCode::BAD_GATEWAY, format!("[Sg.Client] Client error: {err}"))), + } +} + +pub fn get_http_backend_service() -> ArcHyperService { + ArcHyperService::new(hyper::service::service_fn(http_backend_service)) +} + +pub fn get_echo_service() -> ArcHyperService { + ArcHyperService::new(hyper::service::service_fn(echo::echo)) +} diff --git a/crates/kernel/src/service/echo.rs b/crates/kernel/src/backend_service/echo.rs similarity index 100% rename from crates/kernel/src/service/echo.rs rename to crates/kernel/src/backend_service/echo.rs diff --git a/crates/kernel/src/service/http_client_service.rs b/crates/kernel/src/backend_service/http_client_service.rs similarity index 90% rename from crates/kernel/src/service/http_client_service.rs rename to crates/kernel/src/backend_service/http_client_service.rs index 20eb0fd3..474cdf74 100644 --- a/crates/kernel/src/service/http_client_service.rs +++ b/crates/kernel/src/backend_service/http_client_service.rs @@ -73,19 +73,19 @@ fn get_rustls_config_dangerous() -> rustls::ClientConfig { config } -pub fn get_client() -> SgHttpClient { +pub fn get_client() -> HttpClient { ClientRepo::global().get_default() } pub struct ClientRepo { - default: SgHttpClient, - repo: Mutex>, + default: HttpClient, + repo: Mutex>, } impl Default for ClientRepo { fn default() -> Self { let config = get_rustls_config_dangerous(); - let default = SgHttpClient::new(config); + let default = HttpClient::new(config); Self { default, repo: Default::default(), @@ -95,19 +95,19 @@ impl Default for ClientRepo { static mut GLOBAL: OnceLock = OnceLock::new(); impl ClientRepo { - pub fn get(&self, code: &str) -> Option { + pub fn get(&self, code: &str) -> Option { self.repo.lock().expect("failed to lock client repo").get(code).cloned() } - pub fn get_or_default(&self, code: &str) -> SgHttpClient { + pub fn get_or_default(&self, code: &str) -> HttpClient { self.get(code).unwrap_or_else(|| self.default.clone()) } - pub fn get_default(&self) -> SgHttpClient { + pub fn get_default(&self) -> HttpClient { self.default.clone() } - pub fn register(&self, code: &str, client: SgHttpClient) { + pub fn register(&self, code: &str, client: HttpClient) { self.repo.lock().expect("failed to lock client repo").insert(code.to_string(), client); } - pub fn set_default(&mut self, client: SgHttpClient) { + pub fn set_default(&mut self, client: HttpClient) { self.default = client; } pub fn global() -> &'static Self { @@ -116,7 +116,7 @@ impl ClientRepo { /// # Safety /// This function is not thread safe, it should be called before any other thread is spawned. - pub unsafe fn set_global_default(client: SgHttpClient) { + pub unsafe fn set_global_default(client: HttpClient) { GLOBAL.get_or_init(Default::default); GLOBAL.get_mut().expect("global not set").set_default(client); } @@ -127,19 +127,19 @@ pub struct SgHttpClientConfig { } #[derive(Debug, Clone)] -pub struct SgHttpClient { +pub struct HttpClient { inner: Client, SgBody>, } -impl Default for SgHttpClient { +impl Default for HttpClient { fn default() -> Self { Self::new(rustls::ClientConfig::builder().with_native_roots().expect("failed to init rustls config").with_no_client_auth()) } } -impl SgHttpClient { +impl HttpClient { pub fn new(tls_config: rustls::ClientConfig) -> Self { - SgHttpClient { + HttpClient { inner: Client::builder(TokioExecutor::new()).build(HttpsConnectorBuilder::new().with_tls_config(tls_config).https_or_http().enable_http1().enable_http2().build()), } } diff --git a/crates/kernel/src/service/static_file_service.rs b/crates/kernel/src/backend_service/static_file_service.rs similarity index 89% rename from crates/kernel/src/service/static_file_service.rs rename to crates/kernel/src/backend_service/static_file_service.rs index b6244517..51b5e1a3 100644 --- a/crates/kernel/src/service/static_file_service.rs +++ b/crates/kernel/src/backend_service/static_file_service.rs @@ -6,6 +6,7 @@ use hyper::{ HeaderMap, Response, StatusCode, }; use tokio::io::AsyncReadExt; +use tracing::{instrument, trace}; use crate::{extension::Reflect, SgBody, SgRequest, SgResponse}; @@ -38,14 +39,23 @@ pub fn cache_policy(metadata: &Metadata) -> bool { size < (1 << 20) } -/// +#[instrument()] pub async fn static_file_service(mut request: SgRequest, dir: &Path) -> SgResponse { - // request.headers().get() let mut response = Response::builder().body(SgBody::empty()).expect("failed to build response"); if let Some(reflect) = request.extensions_mut().remove::() { *response.extensions_mut() = reflect.into_inner(); } - let path = dir.join(request.uri().path()).canonicalize().unwrap_or(dir.to_path_buf()); + let Ok(dir) = dir.canonicalize() else { + *response.status_mut() = StatusCode::FORBIDDEN; + return response; + }; + + let Ok(path) = dir.join(request.uri().path().trim_start_matches('/')).canonicalize() else { + *response.status_mut() = StatusCode::FORBIDDEN; + return response; + }; + + trace!("static file path: {:?}", path); if !path.starts_with(dir) { *response.status_mut() = StatusCode::FORBIDDEN; return response; diff --git a/crates/kernel/src/service/ws_client_service.rs b/crates/kernel/src/backend_service/ws_client_service.rs similarity index 100% rename from crates/kernel/src/service/ws_client_service.rs rename to crates/kernel/src/backend_service/ws_client_service.rs diff --git a/crates/kernel/src/extension/matched.rs b/crates/kernel/src/extension/matched.rs index 09cb8c9a..e2ac2189 100644 --- a/crates/kernel/src/extension/matched.rs +++ b/crates/kernel/src/extension/matched.rs @@ -1,4 +1,4 @@ -use crate::{helper_layers::route::Router, layers::http_route::match_request::HttpRouteMatch}; +use crate::{helper_layers::route::Router, service::http_route::match_request::HttpRouteMatch}; use std::{ops::Deref, sync::Arc}; #[derive(Debug, Clone)] diff --git a/crates/kernel/src/header.rs b/crates/kernel/src/header.rs deleted file mode 100644 index b9e35b3c..00000000 --- a/crates/kernel/src/header.rs +++ /dev/null @@ -1,4 +0,0 @@ -#![allow(clippy::declare_interior_mutable_const)] -use hyper::header::HeaderName; - -pub const X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for"); diff --git a/crates/kernel/src/helper_layers/function.rs b/crates/kernel/src/helper_layers/function.rs index cb012489..5e916e24 100644 --- a/crates/kernel/src/helper_layers/function.rs +++ b/crates/kernel/src/helper_layers/function.rs @@ -165,14 +165,14 @@ mod test { resp } } - use crate::SgBoxLayer; + use crate::BoxLayer; use super::*; #[test] fn test_fn_layer() { let status_message = Arc::new(>::default()); - let boxed_layer = SgBoxLayer::new(FnLayer::new(MyPlugin::default())); - let boxed_layer2 = SgBoxLayer::new(FnLayer::new_closure(move |req, inner| { + let boxed_layer = BoxLayer::new(FnLayer::new(MyPlugin::default())); + let boxed_layer2 = BoxLayer::new(FnLayer::new_closure(move |req, inner| { let host = req.headers().get("host"); if let Some(Ok(host)) = host.map(HeaderValue::to_str) { println!("{host}"); diff --git a/crates/kernel/src/helper_layers/random_pick.rs b/crates/kernel/src/helper_layers/random_pick.rs index ad38c90b..2d6b09f0 100644 --- a/crates/kernel/src/helper_layers/random_pick.rs +++ b/crates/kernel/src/helper_layers/random_pick.rs @@ -35,9 +35,13 @@ where type Error = S::Error; type Future = S::Future; + #[allow(clippy::indexing_slicing)] fn call(&self, req: R) -> Self::Future { - let index = self.picker.sample(&mut rand::thread_rng()); - #[allow(clippy::indexing_slicing)] - self.services[index].call(req) + if self.services.len() == 1 { + self.services[0].call(req) + } else { + let index = self.picker.sample(&mut rand::thread_rng()); + self.services[index].call(req) + } } } diff --git a/crates/kernel/src/helper_layers/route.rs b/crates/kernel/src/helper_layers/route.rs index c74c5a5e..713bd124 100644 --- a/crates/kernel/src/helper_layers/route.rs +++ b/crates/kernel/src/helper_layers/route.rs @@ -12,7 +12,7 @@ pub trait Router: Clone { } #[derive(Debug, Clone)] -pub struct Route +pub struct RouterService where R: Router, { @@ -21,7 +21,7 @@ where router: R, } -impl Route +impl RouterService where R: Router, S: Index, @@ -31,7 +31,7 @@ where } } -impl hyper::service::Service> for Route +impl hyper::service::Service> for RouterService where R: Router + Send + Sync + 'static, R::Index: Send + Sync + 'static + Clone, diff --git a/crates/kernel/src/helper_layers/timeout.rs b/crates/kernel/src/helper_layers/timeout.rs index 2342beed..6583372d 100644 --- a/crates/kernel/src/helper_layers/timeout.rs +++ b/crates/kernel/src/helper_layers/timeout.rs @@ -1,17 +1,14 @@ -use std::{ - convert::Infallible, - time::{Duration, Instant}, -}; +use std::{convert::Infallible, time::Duration}; +use crate::SgBody; use futures_util::Future; use hyper::{Request, Response}; +use tokio::time::Sleep; use tower_layer::Layer; - -use crate::SgBody; #[derive(Clone)] pub struct TimeoutLayer { - /// timeout duration, none value means no timeout - pub timeout: Option, + /// timeout duration + pub timeout: Duration, pub timeout_response: hyper::body::Bytes, } @@ -30,24 +27,24 @@ impl Layer for TimeoutLayer { #[derive(Clone)] pub struct Timeout { inner: S, - timeout: Option, + timeout: Duration, timeout_response: hyper::body::Bytes, } impl TimeoutLayer { - pub fn new(timeout: Option) -> Self { + pub fn new(timeout: Duration) -> Self { Self { timeout, timeout_response: hyper::body::Bytes::default(), } } - pub fn set_timeout(&mut self, timeout: Option) { + pub fn set_timeout(&mut self, timeout: Duration) { self.timeout = timeout; } } impl Timeout { - pub fn new(timeout: Option, timeout_response: hyper::body::Bytes, inner: S) -> Self { + pub fn new(timeout: Duration, timeout_response: hyper::body::Bytes, inner: S) -> Self { Self { inner, timeout, timeout_response } } } @@ -66,7 +63,7 @@ where fn call(&self, req: Request) -> Self::Future { TimeoutFuture { inner: self.inner.call(req), - timeout_at: self.timeout.map(|d| Instant::now() + d), + timeout: tokio::time::sleep(self.timeout), timeout_response: self.timeout_response.clone(), } } @@ -76,7 +73,8 @@ pin_project_lite::pin_project! { pub struct TimeoutFuture { #[pin] inner: F, - timeout_at: Option, + #[pin] + timeout: Sleep, timeout_response: hyper::body::Bytes, } } @@ -89,11 +87,9 @@ where fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { let this = self.project(); - if let Some(timeout_at) = this.timeout_at { - if Instant::now() >= *timeout_at { - let response = Response::builder().status(hyper::StatusCode::GATEWAY_TIMEOUT).body(SgBody::full(this.timeout_response.clone())).expect("invalid response"); - return std::task::Poll::Ready(Ok(response)); - } + if this.timeout.poll(cx).is_ready() { + let response = Response::builder().status(hyper::StatusCode::GATEWAY_TIMEOUT).body(SgBody::full(this.timeout_response.clone())).expect("invalid response"); + return std::task::Poll::Ready(Ok(response)); } this.inner.poll(cx) } diff --git a/crates/kernel/src/layers.rs b/crates/kernel/src/layers.rs deleted file mode 100644 index 57743736..00000000 --- a/crates/kernel/src/layers.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub mod http_route; - -pub mod gateway; diff --git a/crates/kernel/src/lib.rs b/crates/kernel/src/lib.rs index 4c2385ea..fef952fd 100644 --- a/crates/kernel/src/lib.rs +++ b/crates/kernel/src/lib.rs @@ -1,20 +1,19 @@ #![deny(clippy::unwrap_used, clippy::dbg_macro, clippy::unimplemented, clippy::todo)] #![warn(clippy::missing_errors_doc, clippy::indexing_slicing)] // pub mod config; +pub mod backend_service; pub mod body; pub mod extension; pub mod extractor; -pub mod header; pub mod helper_layers; -pub mod layers; pub mod listener; pub mod service; pub mod utils; +pub use backend_service::ArcHyperService; pub use body::SgBody; use extension::Reflect; pub use extractor::Extractor; -pub use service::ArcHyperService; use std::{convert::Infallible, fmt}; pub use tower_layer::Layer; @@ -97,11 +96,11 @@ impl SgResponseExt for Response { pub type ReqOrResp = Result, Response>; -pub struct SgBoxLayer { +pub struct BoxLayer { boxed: Box + Send + Sync + 'static>, } -impl SgBoxLayer { +impl BoxLayer { /// Create a new [`SgBoxLayer`]. pub fn new(inner_layer: L) -> Self where @@ -122,7 +121,7 @@ impl SgBoxLayer { } } -impl Layer for SgBoxLayer +impl Layer for BoxLayer where S: Clone + hyper::service::Service, Response = Response, Error = Infallible> + Send + Sync + 'static, >>::Future: std::marker::Send, @@ -134,7 +133,7 @@ where } } -impl fmt::Debug for SgBoxLayer { +impl fmt::Debug for BoxLayer { fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { fmt.debug_struct("BoxLayer").finish() } diff --git a/crates/kernel/src/listener.rs b/crates/kernel/src/listener.rs index cd69cd25..cc891f99 100644 --- a/crates/kernel/src/listener.rs +++ b/crates/kernel/src/listener.rs @@ -116,30 +116,34 @@ where S: hyper::service::Service, Error = Infallible, Response = Response> + Clone + Send + 'static, S::Future: Send + 'static, { - #[instrument(skip(stream, service, tls_cfg, conn_builder, cancel_token))] + #[instrument(skip(stream, service, tls_cfg, conn_builder))] async fn accept( conn_builder: hyper_util::server::conn::auto::Builder, stream: TcpStream, peer_addr: SocketAddr, tls_cfg: Option>, - #[allow(unused_variables)] cancel_token: CancellationToken, service: S, - ) -> Result<(), BoxError> { + ) { tracing::debug!("[Sg.Listen] Accepted connection"); let service = HyperServiceAdapter::new(service, peer_addr); - if let Some(tls_cfg) = tls_cfg { + let conn_result = if let Some(tls_cfg) = tls_cfg { let connector = tokio_rustls::TlsAcceptor::from(tls_cfg); - let accepted = connector.accept(stream).await?; + let Ok(accepted) = connector.accept(stream).await.inspect_err(|e| tracing::warn!("[Sg.Listen] Tls connect error: {:?}", e)) else { + return; + }; let io = TokioIo::new(accepted); let conn = conn_builder.serve_connection_with_upgrades(io, service); - conn.await?; + conn.await } else { let io = TokioIo::new(stream); let conn = conn_builder.serve_connection_with_upgrades(io, service); - conn.await?; + conn.await + }; + if let Err(e) = conn_result { + tracing::warn!("[Sg.Listen] Connection closed with error {e}") + } else { + tracing::debug!("[Sg.Listen] Connection closed"); } - tracing::debug!("[Sg.Listen] Connection closed"); - Ok(()) } #[instrument()] pub async fn listen(self) -> Result<(), BoxError> { @@ -148,28 +152,22 @@ where let cancel_token = self.cancel_token; tracing::debug!("[Sg.Listen] start listening..."); loop { - tokio::select! { + let accepted = tokio::select! { () = cancel_token.cancelled() => { tracing::warn!("[Sg.Listen] cancelled"); return Ok(()); }, - accepted = listener.accept() => { - match accepted { - Ok((stream, peer_addr)) => { - let tls_cfg = self.tls_cfg.clone(); - let service = self.service.clone(); - let builder = self.conn_builder.clone(); - let cancel_token = cancel_token.clone(); - tokio::spawn(async move { - if let Err(e) = Self::accept(builder, stream, peer_addr, tls_cfg, cancel_token, service).await { - tracing::warn!("[Sg.Listen] Accept stream error: {:?}", e); - } - }); - }, - Err(e) => { - tracing::warn!("[Sg.Listen] Accept tcp connection error: {:?}", e); - } - } + accepted = listener.accept() => accepted + }; + match accepted { + Ok((stream, peer_addr)) => { + let tls_cfg = self.tls_cfg.clone(); + let service = self.service.clone(); + let builder = self.conn_builder.clone(); + tokio::spawn(Self::accept(builder, stream, peer_addr, tls_cfg, service)); + } + Err(e) => { + tracing::warn!("[Sg.Listen] Accept tcp connection error: {:?}", e); } } } diff --git a/crates/kernel/src/service.rs b/crates/kernel/src/service.rs index 9e2626e7..57743736 100644 --- a/crates/kernel/src/service.rs +++ b/crates/kernel/src/service.rs @@ -1,147 +1,3 @@ -use std::convert::Infallible; -use std::path::Path; -use std::sync::Arc; +pub mod http_route; -use futures_util::future::BoxFuture; -use futures_util::Future; -use hyper::{header::UPGRADE, Request, Response, StatusCode}; -use tracing::instrument; - -use crate::helper_layers::map_future::MapFuture; -use crate::service::http_client_service::get_client; -use crate::utils::x_forwarded_for; -use crate::BoxError; -use crate::SgBody; -use crate::SgResponse; -use crate::SgResponseExt; - -pub mod echo; -pub mod http_client_service; -pub mod static_file_service; -pub mod ws_client_service; -pub(crate) const FILE_SCHEMA: &str = "file"; -pub trait CloneHyperService: hyper::service::Service { - fn clone_box(&self) -> Box + Send + Sync>; -} - -impl CloneHyperService for T -where - T: hyper::service::Service + Send + Sync + Clone + 'static, -{ - fn clone_box(&self) -> Box + Send + Sync> { - Box::new(self.clone()) - } -} -pub struct ArcHyperService { - pub boxed: Arc< - dyn CloneHyperService, Response = Response, Error = Infallible, Future = BoxFuture<'static, Result, Infallible>>> + Send + Sync, - >, -} - -impl std::fmt::Debug for ArcHyperService { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("ArcHyperService").finish() - } -} - -impl Clone for ArcHyperService { - fn clone(&self) -> Self { - Self { boxed: self.boxed.clone() } - } -} - -impl ArcHyperService { - pub fn new(service: T) -> Self - where - T: Clone + CloneHyperService, Response = Response, Error = Infallible> + Send + Sync + 'static, - T::Future: Future, Infallible>> + 'static + Send, - { - let map_fut = MapFuture::new(service, |fut| Box::pin(fut) as _); - Self { boxed: Arc::new(map_fut) } - } -} - -impl hyper::service::Service> for ArcHyperService { - type Response = Response; - type Error = Infallible; - type Future = BoxFuture<'static, Result>; - - fn call(&self, req: Request) -> Self::Future { - Box::pin(self.boxed.call(req)) - } -} - -/// Http backend service -/// -/// This function could be a bottom layer of a http router, it will handle http and websocket request. -/// -/// This can handle both websocket connection and http request. -/// -/// # Errors -/// 1. Fail to collect body chunks -/// 2. Fail to upgrade -pub async fn http_backend_service_inner(mut req: Request) -> Result { - tracing::trace!(elapsed = ?req.extensions().get::().map(crate::extension::EnterTime::elapsed), "start a backend request"); - x_forwarded_for(&mut req)?; - if req.uri().scheme_str() == Some(FILE_SCHEMA) { - return Ok(static_file_service::static_file_service(req, Path::new("./")).await); - } - let mut client = get_client(); - let mut response = if req.headers().get(UPGRADE).is_some_and(|upgrade| upgrade.as_bytes().eq_ignore_ascii_case(b"websocket")) { - // we only support websocket upgrade now - // if !upgrade.as_bytes().eq_ignore_ascii_case(b"websocket") { - // return Ok(Response::with_code_message(StatusCode::NOT_IMPLEMENTED, "[Sg.Websocket] unsupported upgrade protocol")); - // } - // dump request - let (part, body) = req.into_parts(); - let body = body.dump().await?; - let req = Request::from_parts(part, body); - - // forward request - let resp = client.request(req.clone()).await; - - // dump response - let (part, body) = resp.into_parts(); - let body = body.dump().await?; - let resp = Response::from_parts(part, body); - - let req_for_upgrade = req.clone(); - let resp_for_upgrade = resp.clone(); - - // create forward task - tokio::task::spawn(async move { - // update both side - let (s, c) = futures_util::join!(hyper::upgrade::on(req_for_upgrade), hyper::upgrade::on(resp_for_upgrade)); - let upgrade_as_server = s?; - let upgrade_as_client = c?; - // start a websocket forward - ws_client_service::tcp_transfer(upgrade_as_server, upgrade_as_client).await?; - >::Ok(()) - }); - tracing::trace!(elapsed = ?resp.extensions().get::().map(crate::extension::EnterTime::elapsed), "finish backend websocket forward"); - // return response to client - resp - } else { - let resp = client.request(req).await; - tracing::trace!(elapsed = ?resp.extensions().get::().map(crate::extension::EnterTime::elapsed), "finish backend request"); - resp - }; - response.extensions_mut().insert(unsafe { crate::extension::FromBackend::new() }); - Ok(response) -} - -#[instrument] -pub async fn http_backend_service(req: Request) -> Result, Infallible> { - match http_backend_service_inner(req).await { - Ok(resp) => Ok(resp), - Err(err) => Ok(Response::with_code_message(StatusCode::BAD_GATEWAY, format!("[Sg.Client] Client error: {err}"))), - } -} - -pub fn get_http_backend_service() -> ArcHyperService { - ArcHyperService::new(hyper::service::service_fn(http_backend_service)) -} - -pub fn get_echo_service() -> ArcHyperService { - ArcHyperService::new(hyper::service::service_fn(echo::echo)) -} +pub mod gateway; diff --git a/crates/kernel/src/layers/gateway.rs b/crates/kernel/src/service/gateway.rs similarity index 71% rename from crates/kernel/src/layers/gateway.rs rename to crates/kernel/src/service/gateway.rs index ec86d423..8a189ae8 100644 --- a/crates/kernel/src/layers/gateway.rs +++ b/crates/kernel/src/service/gateway.rs @@ -1,25 +1,25 @@ pub mod builder; -use std::{collections::HashMap, convert::Infallible, ops::Index, sync::Arc}; +use std::{collections::HashMap, ops::Index, sync::Arc}; use crate::{ + backend_service::ArcHyperService, extension::{GatewayName, MatchedSgRouter}, helper_layers::{ map_request::{add_extension::add_extension, MapRequestLayer}, reload::Reloader, - route::{Route, Router}, + route::{Router, RouterService}, }, - service::ArcHyperService, - utils::fold_sg_layers::sg_layers, - SgBody, SgBoxLayer, + utils::fold_box_layers::fold_layers, + BoxLayer, SgBody, }; -use hyper::{header::HOST, Request, Response}; +use hyper::{header::HOST, Request}; use tower_layer::Layer; use tracing::{debug, instrument}; -use super::http_route::{match_hostname::HostnameTree, match_request::MatchRequest, SgHttpRoute, SgHttpRouter}; +use super::http_route::{match_hostname::HostnameTree, match_request::MatchRequest, HttpRoute, HttpRouter}; /**************************************************************************************** @@ -27,46 +27,61 @@ use super::http_route::{match_hostname::HostnameTree, match_request::MatchReques *****************************************************************************************/ -pub type SgGatewayRoute = Route; +pub type HttpRouterService = RouterService; #[derive(Debug)] -pub struct SgGatewayLayer { +pub struct Gateway { pub gateway_name: Arc, - pub http_routes: HashMap, - pub http_plugins: Vec, - pub http_fallback: SgBoxLayer, - pub http_route_reloader: Reloader, + pub http_routes: HashMap, + pub http_plugins: Vec, + pub http_fallback: ArcHyperService, + pub http_route_reloader: Reloader, pub ext: hyper::http::Extensions, } -impl SgGatewayLayer { +impl Gateway { /// Create a new gateway layer. /// # Arguments /// * `gateway_name` - The gateway name, this may be used by plugins. - pub fn builder(gateway_name: impl Into>) -> builder::SgGatewayLayerBuilder { - builder::SgGatewayLayerBuilder::new(gateway_name) + pub fn builder(gateway_name: impl Into>) -> builder::GatewayBuilder { + builder::GatewayBuilder::new(gateway_name) + } + pub fn as_service(&self) -> ArcHyperService { + let gateway_name = GatewayName::new(self.gateway_name.clone()); + let add_gateway_name_layer = MapRequestLayer::new(add_extension(gateway_name, true)); + let gateway_plugins = self.http_plugins.iter(); + let http_routes = self.http_routes.values(); + let route = create_http_router(http_routes, self.http_fallback.clone()); + #[cfg(feature = "reload")] + let service = { + let reloader = self.http_route_reloader.clone(); + reloader.into_layer().layer(route) + }; + #[cfg(not(feature = "reload"))] + let service = route; + ArcHyperService::new(add_gateway_name_layer.layer(fold_layers(gateway_plugins, ArcHyperService::new(service)))) } } #[derive(Debug, Clone)] -pub struct SgGatewayRoutedServices { +pub struct HttpRoutedService { services: Arc<[Vec]>, } #[derive(Debug, Clone)] -pub struct SgGatewayRouter { - pub routers: Arc<[SgHttpRouter]>, +pub struct GatewayRouter { + pub routers: Arc<[HttpRouter]>, pub hostname_tree: Arc>>, } -impl Index<(usize, usize)> for SgGatewayRoutedServices { +impl Index<(usize, usize)> for HttpRoutedService { type Output = ArcHyperService; fn index(&self, index: (usize, usize)) -> &Self::Output { #[allow(clippy::indexing_slicing)] &self.services.as_ref()[index.0][index.1] } } -impl Router for SgGatewayRouter { +impl Router for GatewayRouter { type Index = (usize, usize); #[instrument(skip_all, fields(uri = req.uri().to_string(), method = req.method().as_str(), host = ?req.headers().get(HOST) ))] fn route(&self, req: &mut Request) -> Option { @@ -101,35 +116,7 @@ impl Router for SgGatewayRouter { } } -impl Layer for SgGatewayLayer -where - S: Clone + hyper::service::Service, Error = Infallible, Response = Response> + Send + Sync + 'static, - >>::Future: std::marker::Send, -{ - type Service = ArcHyperService; - - fn layer(&self, inner: S) -> Self::Service { - let gateway_name = GatewayName::new(self.gateway_name.clone()); - let add_gateway_name_layer = MapRequestLayer::new(add_extension(gateway_name, true)); - let gateway_plugins = self.http_plugins.iter(); - let http_routes = self.http_routes.values(); - let route = create_http_router(http_routes, &self.http_fallback, inner); - #[cfg(feature = "reload")] - let service = { - let reloader = self.http_route_reloader.clone(); - reloader.into_layer().layer(route) - }; - #[cfg(not(feature = "reload"))] - let service = route; - ArcHyperService::new(add_gateway_name_layer.layer(sg_layers(gateway_plugins, ArcHyperService::new(service)))) - } -} - -pub fn create_http_router<'a, S>(routes: impl Iterator, fallback: &SgBoxLayer, inner: S) -> Route -where - S: Clone + hyper::service::Service, Error = Infallible, Response = Response> + Send + Sync + 'static, - >>::Future: std::marker::Send, -{ +pub fn create_http_router<'a>(routes: impl Iterator, fallback: ArcHyperService) -> RouterService { let mut services = Vec::new(); let mut routers = Vec::new(); let mut hostname_tree = HostnameTree::>::new(); @@ -140,7 +127,7 @@ where let mut rules_services = Vec::with_capacity(route.rules.len()); let mut rules_router = Vec::with_capacity(route.rules.len()); for rule in route.rules.iter() { - let rule_service = sg_layers(route.plugins.iter(), ArcHyperService::new(rule.layer(inner.clone()))); + let rule_service = fold_layers(route.plugins.iter(), ArcHyperService::new(rule.as_service())); rules_services.push(rule_service); rules_router.push(rule.r#match.clone()); } @@ -160,7 +147,7 @@ where } } services.push(rules_services); - routers.push(SgHttpRouter { + routers.push(HttpRouter { hostnames: route.hostnames.clone().into(), rules: rules_router.into_iter().map(|x| x.map(|v| v.into_iter().map(Arc::new).collect::>())).collect(), ext: route.ext.clone(), @@ -171,12 +158,12 @@ where // we put the highest priority at the front of the vector hostname_tree.iter_mut().for_each(|indices| indices.sort_unstable_by_key(|(_, p)| i16::MAX - *p)); debug!("hostname_tree: {hostname_tree:?}"); - Route::new( - SgGatewayRoutedServices { services: services.into() }, - SgGatewayRouter { + RouterService::new( + HttpRoutedService { services: services.into() }, + GatewayRouter { routers: routers.into(), hostname_tree: Arc::new(hostname_tree), }, - fallback.layer(inner), + fallback, ) } diff --git a/crates/kernel/src/layers/gateway/builder.rs b/crates/kernel/src/service/gateway/builder.rs similarity index 58% rename from crates/kernel/src/layers/gateway/builder.rs rename to crates/kernel/src/service/gateway/builder.rs index 518bc5d2..b9a6ff10 100644 --- a/crates/kernel/src/layers/gateway/builder.rs +++ b/crates/kernel/src/service/gateway/builder.rs @@ -1,39 +1,34 @@ use std::{collections::HashMap, sync::Arc}; +use hyper::{service::service_fn, Response}; + use crate::{ - helper_layers::{ - filter::{response_anyway::ResponseAnyway, FilterRequestLayer}, - function::FnLayer, - reload::Reloader, - }, - layers::http_route::SgHttpRoute, + helper_layers::{function::FnLayer, reload::Reloader}, + service::http_route::HttpRoute, utils::Snowflake, - SgBoxLayer, + ArcHyperService, BoxLayer, SgBody, }; -use super::{SgGatewayLayer, SgGatewayRoute}; +use super::{Gateway, HttpRouterService}; -pub struct SgGatewayLayerBuilder { +pub struct GatewayBuilder { pub gateway_name: Arc, - pub http_routers: HashMap, - pub http_plugins: Vec, - pub http_fallback: SgBoxLayer, - pub http_route_reloader: Reloader, + pub http_routers: HashMap, + pub http_plugins: Vec, + pub http_fallback: ArcHyperService, + pub http_route_reloader: Reloader, pub extensions: hyper::http::Extensions, pub x_request_id: bool, } -pub fn default_gateway_route_fallback() -> SgBoxLayer { - // static LAYER: OnceLock = OnceLock::new(); - // LAYER.get_or_init(|| { - // }) - SgBoxLayer::new(FilterRequestLayer::new(ResponseAnyway { - status: hyper::StatusCode::NOT_FOUND, - message: "[Sg.HttpRouteRule] no rule matched".to_string().into(), +/// return empty 404 not found +pub fn default_gateway_route_fallback() -> ArcHyperService { + ArcHyperService::new(service_fn(|_| async { + Ok(Response::builder().status(hyper::StatusCode::NOT_FOUND).body(SgBody::empty()).expect("bad response")) })) } -impl SgGatewayLayerBuilder { +impl GatewayBuilder { pub fn new(gateway_name: impl Into>) -> Self { Self { gateway_name: gateway_name.into(), @@ -49,30 +44,30 @@ impl SgGatewayLayerBuilder { self.x_request_id = enable; self } - pub fn http_router(mut self, route: SgHttpRoute) -> Self { + pub fn http_router(mut self, route: HttpRoute) -> Self { self.http_routers.insert(route.name.clone(), route); self } - pub fn http_routers(mut self, routes: impl IntoIterator) -> Self { + pub fn http_routers(mut self, routes: impl IntoIterator) -> Self { for (name, mut route) in routes { route.name = name.clone(); self.http_routers.insert(name, route); } self } - pub fn http_plugin(mut self, plugin: SgBoxLayer) -> Self { + pub fn http_plugin(mut self, plugin: BoxLayer) -> Self { self.http_plugins.push(plugin); self } - pub fn http_plugins(mut self, plugins: impl IntoIterator) -> Self { + pub fn http_plugins(mut self, plugins: impl IntoIterator) -> Self { self.http_plugins.extend(plugins); self } - pub fn http_fallback(mut self, fallback: SgBoxLayer) -> Self { + pub fn http_fallback(mut self, fallback: ArcHyperService) -> Self { self.http_fallback = fallback; self } - pub fn http_route_reloader(mut self, reloader: Reloader) -> Self { + pub fn http_route_reloader(mut self, reloader: Reloader) -> Self { self.http_route_reloader = reloader; self } @@ -80,13 +75,13 @@ impl SgGatewayLayerBuilder { self.extensions = extension; self } - pub fn build(self) -> SgGatewayLayer { + pub fn build(self) -> Gateway { let mut plugins = vec![]; if self.x_request_id { - plugins.push(SgBoxLayer::new(FnLayer::new_closure(crate::utils::x_request_id::))); + plugins.push(BoxLayer::new(FnLayer::new_closure(crate::utils::x_request_id::))); } plugins.extend(self.http_plugins); - SgGatewayLayer { + Gateway { gateway_name: self.gateway_name, http_routes: self.http_routers, http_plugins: plugins, diff --git a/crates/kernel/src/layers/http_route.rs b/crates/kernel/src/service/http_route.rs similarity index 61% rename from crates/kernel/src/layers/http_route.rs rename to crates/kernel/src/service/http_route.rs index 866cee8b..1e790e8c 100644 --- a/crates/kernel/src/layers/http_route.rs +++ b/crates/kernel/src/service/http_route.rs @@ -1,24 +1,23 @@ pub mod builder; pub mod match_hostname; pub mod match_request; -use std::{convert::Infallible, sync::Arc, time::Duration}; - +use std::{convert::Infallible, path::PathBuf, sync::Arc, time::Duration}; +const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30); use crate::{ + backend_service::{http_backend_service, static_file_service::static_file_service, ArcHyperService}, extension::{BackendHost, Reflect}, helper_layers::random_pick, - service::ArcHyperService, - utils::{fold_sg_layers::sg_layers, schema_port::port_to_schema}, - SgBody, SgBoxLayer, + utils::{fold_box_layers::fold_layers, schema_port::port_to_schema}, + BoxLayer, SgBody, }; +use futures_util::future::BoxFuture; use hyper::{Request, Response}; -// use tower_http::timeout::{Timeout, TimeoutLayer}; - use tower_layer::Layer; use self::{ - builder::{SgHttpBackendLayerBuilder, SgHttpRouteLayerBuilder, SgHttpRouteRuleLayerBuilder}, + builder::{HttpBackendBuilder, HttpRouteBuilder, HttpRouteRuleBuilder}, match_request::HttpRouteMatch, }; @@ -29,22 +28,22 @@ use self::{ *****************************************************************************************/ #[derive(Debug)] -pub struct SgHttpRoute { +pub struct HttpRoute { pub name: String, pub hostnames: Vec, - pub plugins: Vec, - pub rules: Vec, + pub plugins: Vec, + pub rules: Vec, pub priority: i16, pub ext: hyper::http::Extensions, } -impl SgHttpRoute { - pub fn builder() -> SgHttpRouteLayerBuilder { - SgHttpRouteLayerBuilder::new() +impl HttpRoute { + pub fn builder() -> HttpRouteBuilder { + HttpRouteBuilder::new() } } #[derive(Debug, Clone)] -pub struct SgHttpRouter { +pub struct HttpRouter { pub hostnames: Arc<[String]>, pub rules: Arc<[Option]>>]>, pub ext: hyper::http::Extensions, @@ -57,56 +56,48 @@ pub struct SgHttpRouter { *****************************************************************************************/ #[derive(Debug)] -pub struct SgHttpRouteRuleLayer { +pub struct HttpRouteRule { pub r#match: Option>, - pub plugins: Vec, + pub plugins: Vec, timeouts: Option, - backends: Vec, + backends: Vec, pub ext: hyper::http::Extensions, } -impl SgHttpRouteRuleLayer { - pub fn builder() -> SgHttpRouteRuleLayerBuilder { - SgHttpRouteRuleLayerBuilder::new() +impl HttpRouteRule { + pub fn builder() -> HttpRouteRuleBuilder { + HttpRouteRuleBuilder::new() } -} - -impl Layer for SgHttpRouteRuleLayer -where - S: Clone + hyper::service::Service, Error = Infallible, Response = Response> + Send + Sync + 'static, - >>::Future: std::marker::Send, -{ - type Service = SgRouteRule; - - fn layer(&self, inner: S) -> Self::Service { + pub fn as_service(&self) -> HttpRouteRuleService { use crate::helper_layers::timeout::TimeoutLayer; let empty = self.backends.is_empty(); let filter_layer = self.plugins.iter(); - + let time_out = self.timeouts.unwrap_or(DEFAULT_TIMEOUT); let service = if empty { - sg_layers(filter_layer, ArcHyperService::new(TimeoutLayer::new(self.timeouts).layer(inner))) + fold_layers(filter_layer, ArcHyperService::new(TimeoutLayer::new(time_out).layer(HttpBackendService::http_default()))) } else { - let service_iter = self.backends.iter().map(|l| (l.weight, l.layer(inner.clone()))); + let service_iter = self.backends.iter().map(|l| (l.weight, l.as_service())); let random_picker = random_pick::RandomPick::new(service_iter); - sg_layers(filter_layer, ArcHyperService::new(TimeoutLayer::new(self.timeouts).layer(random_picker))) + fold_layers(filter_layer, ArcHyperService::new(TimeoutLayer::new(time_out).layer(random_picker))) }; let r#match = self.r#match.clone().map(|v| v.into_iter().map(Arc::new).collect::>()); - SgRouteRule { + HttpRouteRuleService { r#match, service, ext: self.ext.clone(), } } } + #[derive(Clone)] -pub struct SgRouteRule { +pub struct HttpRouteRuleService { pub r#match: Option]>>, pub service: ArcHyperService, pub ext: hyper::http::Extensions, } -impl hyper::service::Service> for SgRouteRule { +impl hyper::service::Service> for HttpRouteRuleService { type Response = Response; type Error = Infallible; type Future = >>::Future; @@ -129,68 +120,73 @@ impl hyper::service::Service> for SgRouteRule { *****************************************************************************************/ #[derive(Debug)] -pub struct SgHttpBackendLayer { - pub plugins: Vec, - pub host: Option, - pub port: Option, - pub scheme: Option, +pub struct HttpBackend { + pub plugins: Vec, + pub backend: Backend, pub weight: u16, pub timeout: Option, pub ext: hyper::http::Extensions, } -impl SgHttpBackendLayer { - pub fn builder() -> SgHttpBackendLayerBuilder { - SgHttpBackendLayerBuilder::new() +impl HttpBackend { + pub fn builder() -> HttpBackendBuilder { + HttpBackendBuilder::new() } -} - -impl Layer for SgHttpBackendLayer -where - S: Clone + hyper::service::Service, Error = Infallible, Response = Response> + Send + Sync + 'static, - >>::Future: std::marker::Send, -{ - type Service = SgHttpBackend; - - fn layer(&self, inner: S) -> Self::Service { - let timeout_layer = crate::helper_layers::timeout::TimeoutLayer::new(self.timeout); - let filtered = sg_layers(self.plugins.iter(), ArcHyperService::new(timeout_layer.layer(inner))); - SgHttpBackend { + pub fn as_service(&self) -> ArcHyperService { + let inner_service = HttpBackendService { weight: self.weight, - host: self.host.clone().map(Into::into), - port: self.port, - scheme: self.scheme.clone().map(Into::into), + backend: self.backend.clone().into(), timeout: self.timeout, - inner_service: filtered, ext: self.ext.clone(), - } + }; + let timeout_layer = crate::helper_layers::timeout::TimeoutLayer::new(self.timeout.unwrap_or(DEFAULT_TIMEOUT)); + let filtered = fold_layers(self.plugins.iter(), ArcHyperService::new(timeout_layer.layer(inner_service))); + filtered } } +#[derive(Clone, Debug)] +pub enum Backend { + Http { host: Option, port: Option, schema: Option }, + File { path: PathBuf }, +} + #[derive(Clone)] -pub struct SgHttpBackend { - pub host: Option>, - pub port: Option, - pub scheme: Option>, +pub struct HttpBackendService { + pub backend: Arc, pub weight: u16, pub timeout: Option, - pub inner_service: S, pub ext: hyper::http::Extensions, } -impl hyper::service::Service> for SgHttpBackend -where - S: Clone + hyper::service::Service, Response = Response, Error = Infallible> + Send + 'static, - >>::Future: Send + 'static, -{ +impl HttpBackendService { + pub fn http_default() -> Self { + Self { + backend: Arc::new(Backend::Http { + host: None, + port: None, + schema: None, + }), + weight: 1, + timeout: None, + ext: hyper::http::Extensions::new(), + } + } +} + +impl hyper::service::Service> for HttpBackendService { type Response = Response; type Error = Infallible; - type Future = S::Future; + type Future = BoxFuture<'static, Result, Infallible>>; fn call(&self, req: Request) -> Self::Future { - let map_request = match (self.host.clone(), self.port, self.scheme.clone()) { - (None, None, None) => None, - (host, port, schema) => Some(move |mut req: Request| { + let map_request = match self.backend.as_ref() { + Backend::Http { + host: None, + port: None, + schema: None, + } => None, + Backend::Http { host, port, schema } => Some(move |mut req: Request| { if let Some(ref host) = host { if let Some(reflect) = req.extensions_mut().get_mut::() { reflect.insert(BackendHost::new(host.clone())); @@ -222,9 +218,16 @@ where } req }), + Backend::File { .. } => None, }; - tracing::trace!(elapsed = ?req.extensions().get::().map(crate::extension::EnterTime::elapsed), "enter backend"); let req = if let Some(map_request) = map_request { map_request(req) } else { req }; - self.inner_service.call(req) + let backend = self.backend.clone(); + tracing::trace!(elapsed = ?req.extensions().get::().map(crate::extension::EnterTime::elapsed), "enter backend {backend:?}"); + Box::pin(async move { + match backend.as_ref() { + Backend::Http { .. } => http_backend_service(req).await, + Backend::File { path } => Ok(static_file_service(req, path).await), + } + }) } } diff --git a/crates/kernel/src/layers/http_route/builder.rs b/crates/kernel/src/service/http_route/builder.rs similarity index 55% rename from crates/kernel/src/layers/http_route/builder.rs rename to crates/kernel/src/service/http_route/builder.rs index a6f17a9b..51879211 100644 --- a/crates/kernel/src/layers/http_route/builder.rs +++ b/crates/kernel/src/service/http_route/builder.rs @@ -1,26 +1,26 @@ -use std::time::Duration; +use std::{fmt::Debug, path::PathBuf, time::Duration}; -use crate::{service::FILE_SCHEMA, SgBoxLayer}; +use crate::BoxLayer; -use super::{match_request::HttpRouteMatch, SgHttpBackendLayer, SgHttpRoute, SgHttpRouteRuleLayer}; +use super::{match_request::HttpRouteMatch, Backend, HttpBackend, HttpRoute, HttpRouteRule}; #[derive(Debug)] -pub struct SgHttpRouteLayerBuilder { +pub struct HttpRouteBuilder { pub name: String, pub hostnames: Vec, - pub rules: Vec, - pub plugins: Vec, + pub rules: Vec, + pub plugins: Vec, pub priority: Option, pub extensions: hyper::http::Extensions, } -impl Default for SgHttpRouteLayerBuilder { +impl Default for HttpRouteBuilder { fn default() -> Self { Self::new() } } -impl SgHttpRouteLayerBuilder { +impl HttpRouteBuilder { pub fn new() -> Self { Self { name: Default::default(), @@ -39,19 +39,19 @@ impl SgHttpRouteLayerBuilder { self.hostnames = hostnames.into_iter().collect(); self } - pub fn rule(mut self, rule: SgHttpRouteRuleLayer) -> Self { + pub fn rule(mut self, rule: HttpRouteRule) -> Self { self.rules.push(rule); self } - pub fn rules(mut self, rules: impl IntoIterator) -> Self { + pub fn rules(mut self, rules: impl IntoIterator) -> Self { self.rules.extend(rules); self } - pub fn plugin(mut self, plugin: SgBoxLayer) -> Self { + pub fn plugin(mut self, plugin: BoxLayer) -> Self { self.plugins.push(plugin); self } - pub fn plugins(mut self, plugins: impl IntoIterator) -> Self { + pub fn plugins(mut self, plugins: impl IntoIterator) -> Self { self.plugins.extend(plugins); self } @@ -63,11 +63,11 @@ impl SgHttpRouteLayerBuilder { self.extensions = extensions; self } - pub fn build(mut self) -> SgHttpRoute { + pub fn build(mut self) -> HttpRoute { if self.hostnames.iter().any(|host| host == "*") { self.hostnames = vec!["*".to_string()] } - SgHttpRoute { + HttpRoute { plugins: self.plugins, hostnames: self.hostnames, rules: self.rules, @@ -79,20 +79,20 @@ impl SgHttpRouteLayerBuilder { } #[derive(Debug)] -pub struct SgHttpRouteRuleLayerBuilder { +pub struct HttpRouteRuleBuilder { r#match: Option>, - pub plugins: Vec, + pub plugins: Vec, timeouts: Option, - backends: Vec, + backends: Vec, pub extensions: hyper::http::Extensions, } -impl Default for SgHttpRouteRuleLayerBuilder { +impl Default for HttpRouteRuleBuilder { fn default() -> Self { Self::new() } } -impl SgHttpRouteRuleLayerBuilder { +impl HttpRouteRuleBuilder { pub fn new() -> Self { Self { r#match: None, @@ -117,11 +117,11 @@ impl SgHttpRouteRuleLayerBuilder { self.r#match = None; self } - pub fn plugin(mut self, plugin: SgBoxLayer) -> Self { + pub fn plugin(mut self, plugin: BoxLayer) -> Self { self.plugins.push(plugin); self } - pub fn plugins(mut self, plugins: impl IntoIterator) -> Self { + pub fn plugins(mut self, plugins: impl IntoIterator) -> Self { self.plugins.extend(plugins); self } @@ -129,16 +129,16 @@ impl SgHttpRouteRuleLayerBuilder { self.timeouts = Some(timeout); self } - pub fn backend(mut self, backend: SgHttpBackendLayer) -> Self { + pub fn backend(mut self, backend: HttpBackend) -> Self { self.backends.push(backend); self } - pub fn backends(mut self, backend: impl IntoIterator) -> Self { + pub fn backends(mut self, backend: impl IntoIterator) -> Self { self.backends.extend(backend); self } - pub fn build(self) -> SgHttpRouteRuleLayer { - SgHttpRouteRuleLayer { + pub fn build(self) -> HttpRouteRule { + HttpRouteRule { r#match: self.r#match, plugins: self.plugins, timeouts: self.timeouts, @@ -151,24 +151,50 @@ impl SgHttpRouteRuleLayerBuilder { self } } - +pub trait BackendKindBuilder: Default + Debug { + fn build(self) -> Backend; +} #[derive(Debug)] -pub struct SgHttpBackendLayerBuilder { - host: Option, - port: Option, - schema: Option, - pub plugins: Vec, +pub struct HttpBackendBuilder { + backend: B, + pub plugins: Vec, timeout: Option, weight: u16, pub extensions: hyper::http::Extensions, } -impl Default for SgHttpBackendLayerBuilder { +#[derive(Debug, Default, Clone)] +pub struct HttpBackendKindBuilder { + pub host: Option, + pub port: Option, + pub schema: Option, +} + +impl BackendKindBuilder for HttpBackendKindBuilder { + fn build(self) -> Backend { + Backend::Http { + host: self.host, + port: self.port, + schema: self.schema, + } + } +} +#[derive(Debug, Default, Clone)] + +pub struct FileBackendKindBuilder { + path: PathBuf, +} + +impl BackendKindBuilder for FileBackendKindBuilder { + fn build(self) -> Backend { + Backend::File { path: self.path } + } +} + +impl Default for HttpBackendBuilder { fn default() -> Self { Self { - host: None, - port: None, - schema: None, + backend: B::default(), plugins: Vec::new(), timeout: None, weight: 1, @@ -177,15 +203,46 @@ impl Default for SgHttpBackendLayerBuilder { } } -impl SgHttpBackendLayerBuilder { +impl HttpBackendBuilder { + pub fn path(mut self, path: impl Into) -> Self { + self.backend = FileBackendKindBuilder { path: path.into() }; + self + } +} + +impl HttpBackendBuilder { + pub fn host(mut self, host: impl Into) -> Self { + self.backend = HttpBackendKindBuilder { + host: Some(host.into()), + ..Default::default() + }; + self + } + pub fn port(mut self, port: u16) -> Self { + self.backend = HttpBackendKindBuilder { + port: Some(port), + ..Default::default() + }; + self + } + pub fn schema(mut self, schema: impl Into) -> Self { + self.backend = HttpBackendKindBuilder { + schema: Some(schema.into()), + ..Default::default() + }; + self + } +} + +impl HttpBackendBuilder { pub fn new() -> Self { Self::default() } - pub fn plugin(mut self, plugin: SgBoxLayer) -> Self { + pub fn plugin(mut self, plugin: BoxLayer) -> Self { self.plugins.push(plugin); self } - pub fn plugins(mut self, plugins: impl IntoIterator) -> Self { + pub fn plugins(mut self, plugins: impl IntoIterator) -> Self { self.plugins.extend(plugins); self } @@ -197,31 +254,31 @@ impl SgHttpBackendLayerBuilder { self.weight = weight; self } - pub fn host(mut self, host: impl Into) -> Self { - self.host = Some(host.into()); - self - } - pub fn port(mut self, port: u16) -> Self { - self.port = Some(port); - self - } - pub fn schema(mut self, schema: impl Into) -> Self { - self.schema = Some(schema.into()); - self + pub fn http(self) -> HttpBackendBuilder { + HttpBackendBuilder { + backend: HttpBackendKindBuilder::default(), + plugins: self.plugins, + timeout: self.timeout, + weight: self.weight, + extensions: self.extensions, + } } - pub fn file(mut self) -> Self { - self.schema = Some(FILE_SCHEMA.into()); - self + pub fn file(self) -> HttpBackendBuilder { + HttpBackendBuilder { + backend: FileBackendKindBuilder::default(), + plugins: self.plugins, + timeout: self.timeout, + weight: self.weight, + extensions: self.extensions, + } } pub fn ext(mut self, extension: hyper::http::Extensions) -> Self { self.extensions = extension; self } - pub fn build(self) -> SgHttpBackendLayer { - SgHttpBackendLayer { - host: self.host.map(Into::into), - port: self.port, - scheme: self.schema.map(Into::into), + pub fn build(self) -> HttpBackend { + HttpBackend { + backend: self.backend.build(), plugins: self.plugins, timeout: self.timeout, weight: self.weight, diff --git a/crates/kernel/src/layers/http_route/match_hostname.rs b/crates/kernel/src/service/http_route/match_hostname.rs similarity index 100% rename from crates/kernel/src/layers/http_route/match_hostname.rs rename to crates/kernel/src/service/http_route/match_hostname.rs diff --git a/crates/kernel/src/layers/http_route/match_request.rs b/crates/kernel/src/service/http_route/match_request.rs similarity index 100% rename from crates/kernel/src/layers/http_route/match_request.rs rename to crates/kernel/src/service/http_route/match_request.rs diff --git a/crates/kernel/src/utils.rs b/crates/kernel/src/utils.rs index 0674cb96..2ede4524 100644 --- a/crates/kernel/src/utils.rs +++ b/crates/kernel/src/utils.rs @@ -1,4 +1,4 @@ -pub mod fold_sg_layers; +pub mod fold_box_layers; mod never; pub mod query_kv; pub use never::never; diff --git a/crates/kernel/src/utils/fold_box_layers.rs b/crates/kernel/src/utils/fold_box_layers.rs new file mode 100644 index 00000000..4e244be7 --- /dev/null +++ b/crates/kernel/src/utils/fold_box_layers.rs @@ -0,0 +1,8 @@ +use crate::{ArcHyperService, BoxLayer}; + +pub fn fold_layers<'a>(layers: impl Iterator + std::iter::DoubleEndedIterator, mut inner: ArcHyperService) -> ArcHyperService { + for l in layers.rev() { + inner = l.layer_boxed(inner); + } + inner +} diff --git a/crates/kernel/src/utils/fold_sg_layers.rs b/crates/kernel/src/utils/fold_sg_layers.rs deleted file mode 100644 index 25a5d473..00000000 --- a/crates/kernel/src/utils/fold_sg_layers.rs +++ /dev/null @@ -1,10 +0,0 @@ -use tower_layer::Layer; - -use crate::{ArcHyperService, SgBoxLayer}; - -pub fn sg_layers<'a>(layers: impl Iterator + std::iter::DoubleEndedIterator, mut inner: ArcHyperService) -> ArcHyperService { - for l in layers.rev() { - inner = l.layer(inner); - } - inner -} diff --git a/crates/kernel/src/utils/with_length_or_chunked.rs b/crates/kernel/src/utils/with_length_or_chunked.rs index 4cdde374..319cdff5 100644 --- a/crates/kernel/src/utils/with_length_or_chunked.rs +++ b/crates/kernel/src/utils/with_length_or_chunked.rs @@ -1,18 +1,15 @@ use crate::SgBody; use hyper::{header::HeaderValue, Response}; - pub fn with_length_or_chunked(resp: &mut Response) { + const CHUNKED: &str = "chunked"; resp.headers_mut().remove(hyper::header::CONTENT_LENGTH); - - let is_chunked = resp.headers().get_all(hyper::header::TRANSFER_ENCODING).iter().any(|v| v.as_bytes() == b"chunked"); if let Some(len) = resp.body().get_dumped().map(hyper::body::Bytes::len) { - if !is_chunked { - resp.headers_mut().insert( - hyper::header::CONTENT_LENGTH, - HeaderValue::from_str(len.to_string().as_str()).expect("digits should be valid header char"), - ); - } - } else { - resp.headers_mut().append(hyper::header::TRANSFER_ENCODING, HeaderValue::from_static("chunked")); + resp.headers_mut().remove(hyper::header::TRANSFER_ENCODING); + resp.headers_mut().insert( + hyper::header::CONTENT_LENGTH, + HeaderValue::from_str(len.to_string().as_str()).expect("digits should be valid header char"), + ); + } else if !resp.headers().get_all(hyper::header::TRANSFER_ENCODING).iter().any(|v| v.as_bytes() == CHUNKED.as_bytes()) { + resp.headers_mut().append(hyper::header::TRANSFER_ENCODING, HeaderValue::from_static(CHUNKED)); } } diff --git a/crates/kernel/src/utils/x_forwarded_for.rs b/crates/kernel/src/utils/x_forwarded_for.rs index 57185141..de9c2a2a 100644 --- a/crates/kernel/src/utils/x_forwarded_for.rs +++ b/crates/kernel/src/utils/x_forwarded_for.rs @@ -1,8 +1,8 @@ use crate::BoxError; use hyper::{header::HeaderValue, Request}; -use crate::{extension::PeerAddr, header::X_FORWARDED_FOR, SgBody}; - +use crate::{extension::PeerAddr, SgBody}; +const X_FORWARDED_FOR: &str = "x-forwarded-for"; /// Add `x-forwarded-for` for request, based on [`PeerAddr`](`crate::extension::PeerAddr`) /// # Errors /// missing peer addr ext diff --git a/crates/kernel/tests/test_h2.rs b/crates/kernel/tests/test_h2.rs index 3023d36d..2f0247e7 100644 --- a/crates/kernel/tests/test_h2.rs +++ b/crates/kernel/tests/test_h2.rs @@ -8,12 +8,12 @@ use std::{ use axum_server::tls_rustls::RustlsConfig; use hyper::{client, Request}; use spacegate_kernel::{ - layers::{ + backend_service::{get_http_backend_service, http_backend_service}, + listener::SgListen, + service::{ gateway, - http_route::{match_request::HttpPathMatchRewrite, SgHttpBackendLayer, SgHttpRoute, SgHttpRouteRuleLayer}, + http_route::{match_request::HttpPathMatchRewrite, HttpBackend, HttpRoute, HttpRouteRule}, }, - listener::SgListen, - service::{get_http_backend_service, http_backend_service}, SgBody, }; use tokio_rustls::rustls::ServerConfig; @@ -42,15 +42,15 @@ async fn test_h2() { async fn gateway() { let cancel = CancellationToken::default(); - let gateway = gateway::SgGatewayLayer::builder("test_h2") + let gateway = gateway::Gateway::builder("test_h2") .http_routers([( "test_h2".to_string(), - SgHttpRoute::builder().rule(SgHttpRouteRuleLayer::builder().match_all().backend(SgHttpBackendLayer::builder().host("[::]").port(9003).build()).build()).build(), + HttpRoute::builder().rule(HttpRouteRule::builder().match_all().backend(HttpBackend::builder().host("[::]").port(9003).build()).build()).build(), )]) .build(); let addr = SocketAddr::from_str("[::]:9002").expect("invalid host"); - let listener = SgListen::new(addr, gateway.layer(get_http_backend_service()), cancel, "listener").with_tls_config(tls_config()); + let listener = SgListen::new(addr, gateway.as_service(), cancel, "listener").with_tls_config(tls_config()); listener.listen().await.expect("fail to listen"); } diff --git a/crates/kernel/tests/test_https.rs b/crates/kernel/tests/test_https.rs index e58cb829..56d36cc3 100644 --- a/crates/kernel/tests/test_https.rs +++ b/crates/kernel/tests/test_https.rs @@ -1,16 +1,14 @@ use std::{net::SocketAddr, str::FromStr, time::Duration}; use spacegate_kernel::{ - layers::{ + listener::SgListen, + service::{ gateway, - http_route::{match_request::HttpPathMatchRewrite, SgHttpBackendLayer, SgHttpRoute, SgHttpRouteRuleLayer}, + http_route::{match_request::HttpPathMatchRewrite, HttpBackend, HttpRoute, HttpRouteRule}, }, - listener::SgListen, - service::get_http_backend_service, }; use tokio_rustls::rustls::ServerConfig; use tokio_util::sync::CancellationToken; -use tower_layer::Layer; #[tokio::test] async fn test_https() { tokio::spawn(gateway()); @@ -30,17 +28,15 @@ async fn test_https() { async fn gateway() { let cancel = CancellationToken::default(); - let gateway = gateway::SgGatewayLayer::builder("test_multi_part") + let gateway = gateway::Gateway::builder("test_multi_part") .http_routers([( "test_upload".to_string(), - SgHttpRoute::builder() - .rule( - SgHttpRouteRuleLayer::builder().match_item(HttpPathMatchRewrite::prefix("/tls")).backend(SgHttpBackendLayer::builder().host("[::]").port(9003).build()).build(), - ) + HttpRoute::builder() + .rule(HttpRouteRule::builder().match_item(HttpPathMatchRewrite::prefix("/tls")).backend(HttpBackend::builder().host("[::]").port(9003).build()).build()) .rule( - SgHttpRouteRuleLayer::builder() + HttpRouteRule::builder() .match_item(HttpPathMatchRewrite::prefix("/baidu")) - .backend(SgHttpBackendLayer::builder().schema("https").host("www.baidu.com").port(443).build()) + .backend(HttpBackend::builder().schema("https").host("www.baidu.com").port(443).build()) .build(), ) .build(), @@ -57,13 +53,13 @@ async fn gateway() { .expect("fail to build tls config"); let http_listener = SgListen::new( SocketAddr::from_str("[::]:9080").expect("invalid host"), - gateway.layer(get_http_backend_service()), + gateway.as_service(), cancel.child_token(), "listener", ); let https_listener = SgListen::new( SocketAddr::from_str("[::]:9443").expect("invalid host"), - gateway.layer(get_http_backend_service()), + gateway.as_service(), cancel.child_token(), "listener", ) diff --git a/crates/kernel/tests/test_multi_part.rs b/crates/kernel/tests/test_multi_part.rs index 8a71485d..f0a2720e 100644 --- a/crates/kernel/tests/test_multi_part.rs +++ b/crates/kernel/tests/test_multi_part.rs @@ -5,17 +5,15 @@ use reqwest::{ Body, }; use spacegate_kernel::{ - layers::{ + listener::SgListen, + service::{ gateway, - http_route::{SgHttpBackendLayer, SgHttpRoute, SgHttpRouteRuleLayer}, + http_route::{HttpBackend, HttpRoute, HttpRouteRule}, }, - listener::SgListen, - service::get_http_backend_service, }; use tokio::fs::File; use tokio_util::io::ReaderStream; use tokio_util::sync::CancellationToken; -use tower_layer::Layer; #[tokio::test] async fn test_multi_part() { tokio::spawn(gateway()); @@ -31,14 +29,14 @@ async fn test_multi_part() { async fn gateway() { let cancel = CancellationToken::default(); - let gateway = gateway::SgGatewayLayer::builder("test_multi_part") + let gateway = gateway::Gateway::builder("test_multi_part") .http_routers([( "test_upload".to_string(), - SgHttpRoute::builder().rule(SgHttpRouteRuleLayer::builder().match_all().backend(SgHttpBackendLayer::builder().host("[::]").port(9003).build()).build()).build(), + HttpRoute::builder().rule(HttpRouteRule::builder().match_all().backend(HttpBackend::builder().host("[::]").port(9003).build()).build()).build(), )]) .build(); let addr = SocketAddr::from_str("[::]:9002").expect("invalid host"); - let listener = SgListen::new(addr, gateway.layer(get_http_backend_service()), cancel, "listener"); + let listener = SgListen::new(addr, gateway.as_service(), cancel, "listener"); listener.listen().await.expect("fail to listen"); } diff --git a/crates/kernel/tests/test_websocket.rs b/crates/kernel/tests/test_websocket.rs index 2497d040..6bee78e9 100644 --- a/crates/kernel/tests/test_websocket.rs +++ b/crates/kernel/tests/test_websocket.rs @@ -3,15 +3,13 @@ use futures_util::{SinkExt, StreamExt}; use std::{net::SocketAddr, str::FromStr, time::Duration}; use spacegate_kernel::{ - layers::{ + listener::SgListen, + service::{ gateway, - http_route::{SgHttpBackendLayer, SgHttpRoute, SgHttpRouteRuleLayer}, + http_route::{HttpBackend, HttpRoute, HttpRouteRule}, }, - listener::SgListen, - service::get_http_backend_service, }; use tokio_util::sync::CancellationToken; -use tower_layer::Layer; #[tokio::test] async fn test_ws() { tokio::spawn(gateway()); @@ -32,14 +30,14 @@ async fn test_ws() { async fn gateway() { let cancel = CancellationToken::default(); - let gateway = gateway::SgGatewayLayer::builder("test_websocket") + let gateway = gateway::Gateway::builder("test_websocket") .http_routers([( "ws".to_string(), - SgHttpRoute::builder().rule(SgHttpRouteRuleLayer::builder().match_all().backend(SgHttpBackendLayer::builder().host("[::]").port(9002).build()).build()).build(), + HttpRoute::builder().rule(HttpRouteRule::builder().match_all().backend(HttpBackend::builder().host("[::]").port(9002).build()).build()).build(), )]) .build(); let addr = SocketAddr::from_str("[::]:9003").expect("invalid host"); - let listener = SgListen::new(addr, gateway.layer(get_http_backend_service()), cancel, "listener"); + let listener = SgListen::new(addr, gateway.as_service(), cancel, "listener"); listener.listen().await.expect("fail to listen"); } diff --git a/crates/plugin/src/ext/redis/plugins.rs b/crates/plugin/src/ext/redis/plugins.rs index 9e70f8a4..36139a7d 100644 --- a/crates/plugin/src/ext/redis/plugins.rs +++ b/crates/plugin/src/ext/redis/plugins.rs @@ -1,6 +1,6 @@ use hyper::{header::HeaderName, Request}; use serde::{Deserialize, Serialize}; -use spacegate_kernel::{extension::MatchedSgRouter, layers::http_route::match_request::HttpPathMatchRewrite, SgBody}; +use spacegate_kernel::{extension::MatchedSgRouter, service::http_route::match_request::HttpPathMatchRewrite, SgBody}; pub mod redis_count; pub mod redis_dynamic_route; diff --git a/crates/plugin/src/ext/redis/plugins/redis_count.rs b/crates/plugin/src/ext/redis/plugins/redis_count.rs index 0f5bd43e..6ce2b95d 100644 --- a/crates/plugin/src/ext/redis/plugins/redis_count.rs +++ b/crates/plugin/src/ext/redis/plugins/redis_count.rs @@ -99,8 +99,8 @@ mod test { use hyper::header::AUTHORIZATION; use serde_json::json; use spacegate_kernel::{ - layers::http_route::match_request::{HttpPathMatchRewrite, HttpRouteMatch}, - service::get_echo_service, + backend_service::get_echo_service, + service::http_route::match_request::{HttpPathMatchRewrite, HttpRouteMatch}, }; use testcontainers_modules::redis::REDIS_PORT; diff --git a/crates/plugin/src/ext/redis/plugins/redis_limit.rs b/crates/plugin/src/ext/redis/plugins/redis_limit.rs index 5a312674..08b221e1 100644 --- a/crates/plugin/src/ext/redis/plugins/redis_limit.rs +++ b/crates/plugin/src/ext/redis/plugins/redis_limit.rs @@ -78,8 +78,8 @@ mod test { use serde_json::json; use spacegate_ext_redis::redis::AsyncCommands; use spacegate_kernel::{ - layers::http_route::match_request::{HttpPathMatchRewrite, HttpRouteMatch}, - service::get_echo_service, + backend_service::get_echo_service, + service::http_route::match_request::{HttpPathMatchRewrite, HttpRouteMatch}, }; use std::time::Duration; use testcontainers_modules::redis::REDIS_PORT; diff --git a/crates/plugin/src/ext/redis/plugins/redis_time_range.rs b/crates/plugin/src/ext/redis/plugins/redis_time_range.rs index 1d4de7eb..7f263d26 100644 --- a/crates/plugin/src/ext/redis/plugins/redis_time_range.rs +++ b/crates/plugin/src/ext/redis/plugins/redis_time_range.rs @@ -94,8 +94,8 @@ mod test { use hyper::header::AUTHORIZATION; use serde_json::json; use spacegate_kernel::{ - layers::http_route::match_request::{HttpMethodMatch, HttpPathMatchRewrite, HttpRouteMatch}, - service::get_echo_service, + backend_service::get_echo_service, + service::http_route::match_request::{HttpMethodMatch, HttpPathMatchRewrite, HttpRouteMatch}, }; use testcontainers_modules::redis::REDIS_PORT; diff --git a/crates/plugin/src/instance.rs b/crates/plugin/src/instance.rs index 4c62cb57..b73baccb 100644 --- a/crates/plugin/src/instance.rs +++ b/crates/plugin/src/instance.rs @@ -8,7 +8,7 @@ use std::{ }; use serde::{Deserialize, Serialize}; -use spacegate_kernel::{helper_layers::function::FnLayer, BoxError, BoxResult, SgBoxLayer}; +use spacegate_kernel::{helper_layers::function::FnLayer, BoxError, BoxLayer, BoxResult}; use spacegate_model::PluginConfig; use crate::mount::{MountPoint, MountPointIndex}; @@ -58,7 +58,7 @@ pub(crate) fn drop_trace() -> (DropTracer, DropMarker) { ) } -pub type BoxMakeFn = Box Result + Sync + Send + 'static>; +pub type BoxMakeFn = Box Result + Sync + Send + 'static>; type PluginInstanceHook = Box Result<(), BoxError> + Send + Sync + 'static>; #[derive(Default)] @@ -132,8 +132,8 @@ impl PluginInstance { Ok(()) } } - pub fn make(&self) -> SgBoxLayer { - SgBoxLayer::new(FnLayer::new(self.plugin_function.clone())) + pub fn make(&self) -> BoxLayer { + BoxLayer::new(FnLayer::new(self.plugin_function.clone())) } // if we don't clean the mount_points, it may cause a slow memory leak // we do it before new instance mounted diff --git a/crates/plugin/src/lib.rs b/crates/plugin/src/lib.rs index 589530a6..95356c56 100644 --- a/crates/plugin/src/lib.rs +++ b/crates/plugin/src/lib.rs @@ -18,8 +18,8 @@ pub use serde_json::{Error as SerdeJsonError, Value as JsonValue}; pub use spacegate_kernel::helper_layers::filter::{Filter, FilterRequest, FilterRequestLayer}; pub use spacegate_kernel::helper_layers::function::Inner; pub use spacegate_kernel::BoxError; +pub use spacegate_kernel::BoxLayer; use spacegate_kernel::SgBody; -pub use spacegate_kernel::SgBoxLayer; pub mod error; pub mod model; pub mod mount; diff --git a/crates/plugin/src/model.rs b/crates/plugin/src/model.rs index d8b0f1bc..93f7be63 100644 --- a/crates/plugin/src/model.rs +++ b/crates/plugin/src/model.rs @@ -1,5 +1,5 @@ use serde::{Deserialize, Serialize}; -use spacegate_kernel::layers::http_route::match_request::HttpPathMatchRewrite; +use spacegate_kernel::service::http_route::match_request::HttpPathMatchRewrite; #[derive(Default, Debug, Serialize, Deserialize, Clone)] #[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] diff --git a/crates/plugin/src/mount.rs b/crates/plugin/src/mount.rs index 8d12029e..cf197c86 100644 --- a/crates/plugin/src/mount.rs +++ b/crates/plugin/src/mount.rs @@ -2,9 +2,9 @@ use std::sync::Arc; use serde::{Deserialize, Serialize}; use spacegate_kernel::{ - layers::{ - gateway::SgGatewayLayer, - http_route::{SgHttpBackendLayer, SgHttpRoute, SgHttpRouteRuleLayer}, + service::{ + gateway::Gateway, + http_route::{HttpBackend, HttpRoute, HttpRouteRule}, }, BoxError, }; @@ -97,7 +97,7 @@ pub trait MountPoint { fn mount(&mut self, instance: &mut PluginInstance) -> Result; } -impl MountPoint for SgGatewayLayer { +impl MountPoint for Gateway { fn mount(&mut self, instance: &mut PluginInstance) -> Result { let (tracer, marker) = drop_trace(); self.http_plugins.push(instance.make()); @@ -107,7 +107,7 @@ impl MountPoint for SgGatewayLayer { } } -impl MountPoint for SgHttpRoute { +impl MountPoint for HttpRoute { fn mount(&mut self, instance: &mut PluginInstance) -> Result { let (tracer, marker) = drop_trace(); self.plugins.push(instance.make()); @@ -117,7 +117,7 @@ impl MountPoint for SgHttpRoute { } } -impl MountPoint for SgHttpRouteRuleLayer { +impl MountPoint for HttpRouteRule { fn mount(&mut self, instance: &mut PluginInstance) -> Result { let (tracer, marker) = drop_trace(); self.plugins.push(instance.make()); @@ -127,7 +127,7 @@ impl MountPoint for SgHttpRouteRuleLayer { } } -impl MountPoint for SgHttpBackendLayer { +impl MountPoint for HttpBackend { fn mount(&mut self, instance: &mut PluginInstance) -> Result { let (tracer, marker) = drop_trace(); self.plugins.push(instance.make()); diff --git a/crates/plugin/src/plugins/inject.rs b/crates/plugin/src/plugins/inject.rs index 2075dcf9..16c31e4b 100644 --- a/crates/plugin/src/plugins/inject.rs +++ b/crates/plugin/src/plugins/inject.rs @@ -4,9 +4,9 @@ use crate::Plugin; use hyper::{header::HeaderName, Request}; use hyper::{Method, Response, Uri}; use serde::{Deserialize, Serialize}; +use spacegate_kernel::backend_service::http_client_service::get_client; use spacegate_kernel::extension::Reflect; use spacegate_kernel::helper_layers::function::Inner; -use spacegate_kernel::service::http_client_service::get_client; use spacegate_kernel::BoxError; use spacegate_kernel::SgBody; const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5); diff --git a/crates/plugin/src/plugins/maintenance.rs b/crates/plugin/src/plugins/maintenance.rs index 96df1cd2..dd293c13 100644 --- a/crates/plugin/src/plugins/maintenance.rs +++ b/crates/plugin/src/plugins/maintenance.rs @@ -203,9 +203,9 @@ mod test { use hyper::StatusCode; use hyper::{Method, Request, Version}; use serde_json::json; + use spacegate_kernel::backend_service::get_echo_service; use spacegate_kernel::extension::PeerAddr; use spacegate_kernel::helper_layers::function::Inner; - use spacegate_kernel::service::get_echo_service; use spacegate_kernel::BoxError; use spacegate_kernel::SgBody; use spacegate_model::{PluginInstanceId, PluginInstanceName}; diff --git a/crates/shell/src/config/matches_convert.rs b/crates/shell/src/config/matches_convert.rs index a6fda65f..3444b649 100644 --- a/crates/shell/src/config/matches_convert.rs +++ b/crates/shell/src/config/matches_convert.rs @@ -3,7 +3,7 @@ use std::str::FromStr; use hyper::header::{HeaderName, HeaderValue}; use regex::Regex; use spacegate_config::model as config; -use spacegate_kernel::{layers::http_route::match_request as kernel, BoxError}; +use spacegate_kernel::{service::http_route::match_request as kernel, BoxError}; /// convert [`config::SgHttpRouteMatch`] into [`kernel::SgHttpRouteMatch`] pub(crate) fn convert_config_to_kernel(config_match: config::SgHttpRouteMatch) -> Result { diff --git a/crates/shell/src/server.rs b/crates/shell/src/server.rs index a68390e6..01592b47 100644 --- a/crates/shell/src/server.rs +++ b/crates/shell/src/server.rs @@ -9,10 +9,9 @@ use crate::config::{matches_convert::convert_config_to_kernel, plugin_filter_dto use spacegate_config::{BackendHost, Config, ConfigItem, PluginInstanceId}; use spacegate_kernel::{ helper_layers::reload::Reloader, - layers::gateway::{builder::default_gateway_route_fallback, create_http_router, SgGatewayRoute}, listener::SgListen, - service::get_http_backend_service, - ArcHyperService, BoxError, Layer, + service::gateway::{builder::default_gateway_route_fallback, create_http_router, HttpRouterService}, + ArcHyperService, BoxError, }; use spacegate_plugin::{mount::MountPointIndex, SgPluginRepository}; use std::sync::Arc; @@ -27,7 +26,7 @@ use tokio_util::sync::CancellationToken; fn collect_http_route( gateway_name: Arc, http_routes: impl IntoIterator, -) -> Result, BoxError> { +) -> Result, BoxError> { http_routes .into_iter() .map(|(name, route)| { @@ -47,7 +46,7 @@ fn collect_http_route( gateway: gateway_name.clone(), route: route_name.clone(), }; - let mut builder = spacegate_kernel::layers::http_route::SgHttpRouteRuleLayer::builder(); + let mut builder = spacegate_kernel::service::http_route::HttpRouteRule::builder(); builder = if let Some(matches) = route_rule.matches { builder.matches(matches.into_iter().map(convert_config_to_kernel).collect::, _>>()?) } else { @@ -65,14 +64,14 @@ fn collect_http_route( route: route_name.clone(), }; let host = backend.get_host(); - let mut builder = spacegate_kernel::layers::http_route::SgHttpBackendLayer::builder(); + let mut builder = spacegate_kernel::service::http_route::HttpBackend::builder(); let plugins = backend.plugins; #[cfg(feature = "k8s")] { use crate::extension::k8s_service::K8sService; use spacegate_config::model::BackendHost; use spacegate_kernel::helper_layers::map_request::{add_extension::add_extension, MapRequestLayer}; - use spacegate_kernel::SgBoxLayer; + use spacegate_kernel::BoxLayer; if let BackendHost::K8sService(data) = backend.host { let namespace_ext = K8sService(data.into()); // need to add to front @@ -83,12 +82,13 @@ fn collect_http_route( if let Some(timeout) = backend.timeout_ms.map(|timeout| Duration::from_millis(timeout as u64)) { builder = builder.timeout(timeout) } - if let BackendHost::File { .. } = backend.host { - builder = builder.file() + let mut layer = if let BackendHost::File { path } = backend.host { + builder.file().path(path).build() } else if let Some(protocol) = backend.protocol { - builder = builder.schema(protocol.to_string()); - } - let mut layer = builder.build(); + builder.schema(protocol.to_string()).build() + } else { + builder.build() + }; global_batch_mount_plugin(plugins, &mut layer, mount_index); Result::<_, BoxError>::Ok(layer) }) @@ -103,7 +103,7 @@ fn collect_http_route( }) .collect::, _>>()?; let mut layer = - spacegate_kernel::layers::http_route::SgHttpRoute::builder().hostnames(route.hostnames.unwrap_or_default()).rules(rules).priority(route.priority).build(); + spacegate_kernel::service::http_route::HttpRoute::builder().hostnames(route.hostnames.unwrap_or_default()).rules(rules).priority(route.priority).build(); global_batch_mount_plugin(plugins, &mut layer, mount_index); Ok((name, layer)) }) @@ -115,21 +115,20 @@ pub(crate) fn create_service( gateway_name: &str, plugins: Vec, http_routes: BTreeMap, - reloader: Reloader, + reloader: Reloader, ) -> Result { let gateway_name: Arc = gateway_name.into(); let routes = collect_http_route(gateway_name.clone(), http_routes)?; - let mut layer = spacegate_kernel::layers::gateway::SgGatewayLayer::builder(gateway_name.clone()).http_routers(routes).http_route_reloader(reloader).build(); + let mut layer = spacegate_kernel::service::gateway::Gateway::builder(gateway_name.clone()).http_routers(routes).http_route_reloader(reloader).build(); global_batch_mount_plugin(plugins, &mut layer, MountPointIndex::Gateway { gateway: gateway_name }); - let backend_service = get_http_backend_service(); - let service = ArcHyperService::new(layer.layer(backend_service)); + let service = ArcHyperService::new(layer.as_service()); Ok(service) } /// create a new sg gateway route, which can be sent to reloader -pub(crate) fn create_router_service(gateway_name: Arc, http_routes: BTreeMap) -> Result { +pub(crate) fn create_router_service(gateway_name: Arc, http_routes: BTreeMap) -> Result { let routes = collect_http_route(gateway_name, http_routes.clone())?; - let service = create_http_router(routes.values(), &default_gateway_route_fallback(), get_http_backend_service()); + let service = create_http_router(routes.values(), default_gateway_route_fallback()); Ok(service) } @@ -145,7 +144,7 @@ pub struct RunningSgGateway { pub gateway_name: Arc, token: CancellationToken, handle: tokio::task::JoinHandle<()>, - pub reloader: Reloader, + pub reloader: Reloader, shutdown_timeout: Duration, } impl std::fmt::Debug for RunningSgGateway { @@ -235,7 +234,7 @@ impl RunningSgGateway { } } tracing::info!("[SG.Server] start gateway"); - let reloader = >::default(); + let reloader = >::default(); let service = create_service(&gateway.name, gateway.plugins, routes, reloader.clone())?; if gateway.listeners.is_empty() { error!("[SG.Server] Missing Listeners"); @@ -291,7 +290,7 @@ impl RunningSgGateway { if let Some(key) = key { info!("[SG.Server] using cert key {key:?}"); let mut tls_server_cfg = rustls::ServerConfig::builder().with_no_client_auth().with_single_cert(certs, key)?; - tls_server_cfg.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec(), b"http/1.0".to_vec()]; + tls_server_cfg.alpn_protocols = vec![b"http/1.1".to_vec(), b"http/1.0".to_vec()]; tls_cfg.replace(tls_server_cfg); } else { error!("[SG.Server] Can not found a valid Tls private key"); diff --git a/resource/local-example/config.json b/resource/local-example/config.json index 2cd2a72f..955f0235 100644 --- a/resource/local-example/config.json +++ b/resource/local-example/config.json @@ -49,7 +49,7 @@ { "host": { "kind": "File", - "path": "" + "path": "./" }, "port": 80, "timeout_ms": null,