Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/dev' into update/k8s
Browse files Browse the repository at this point in the history
  • Loading branch information
RWDai committed Apr 22, 2024
2 parents b8580f1 + 5d4ef83 commit e7e4786
Show file tree
Hide file tree
Showing 43 changed files with 590 additions and 563 deletions.
143 changes: 143 additions & 0 deletions crates/kernel/src/backend_service.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
use std::convert::Infallible;
use std::path::Path;
use std::sync::Arc;

use futures_util::future::BoxFuture;
use futures_util::Future;
use hyper::{header::UPGRADE, Request, Response, StatusCode};
use tracing::instrument;

use crate::backend_service::http_client_service::get_client;
use crate::helper_layers::map_future::MapFuture;
use crate::utils::x_forwarded_for;
use crate::BoxError;
use crate::SgBody;
use crate::SgResponse;
use crate::SgResponseExt;

pub mod echo;
pub mod http_client_service;
pub mod static_file_service;
pub mod ws_client_service;
pub(crate) const FILE_SCHEMA: &str = "file";
pub trait CloneHyperService<R>: hyper::service::Service<R> {
fn clone_box(&self) -> Box<dyn CloneHyperService<R, Response = Self::Response, Error = Self::Error, Future = Self::Future> + Send + Sync>;
}

impl<R, T> CloneHyperService<R> for T
where
T: hyper::service::Service<R> + Send + Sync + Clone + 'static,
{
fn clone_box(&self) -> Box<dyn CloneHyperService<R, Response = T::Response, Error = T::Error, Future = T::Future> + Send + Sync> {
Box::new(self.clone())
}
}
pub struct ArcHyperService {
pub boxed: Arc<
dyn CloneHyperService<Request<SgBody>, Response = Response<SgBody>, Error = Infallible, Future = BoxFuture<'static, Result<Response<SgBody>, Infallible>>> + Send + Sync,
>,
}

impl std::fmt::Debug for ArcHyperService {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ArcHyperService").finish()
}
}

impl Clone for ArcHyperService {
fn clone(&self) -> Self {
Self { boxed: self.boxed.clone() }
}
}

impl ArcHyperService {
pub fn new<T>(service: T) -> Self
where
T: Clone + CloneHyperService<Request<SgBody>, Response = Response<SgBody>, Error = Infallible> + Send + Sync + 'static,
T::Future: Future<Output = Result<Response<SgBody>, Infallible>> + 'static + Send,
{
let map_fut = MapFuture::new(service, |fut| Box::pin(fut) as _);
Self { boxed: Arc::new(map_fut) }
}
}

impl hyper::service::Service<Request<SgBody>> for ArcHyperService {
type Response = Response<SgBody>;
type Error = Infallible;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

fn call(&self, req: Request<SgBody>) -> Self::Future {
Box::pin(self.boxed.call(req))
}
}

/// Http backend service
///
/// This function could be a bottom layer of a http router, it will handle http and websocket request.
///
/// This can handle both websocket connection and http request.
///
/// # Errors
/// 1. Fail to collect body chunks
/// 2. Fail to upgrade
pub async fn http_backend_service_inner(mut req: Request<SgBody>) -> Result<SgResponse, BoxError> {
tracing::trace!(elapsed = ?req.extensions().get::<crate::extension::EnterTime>().map(crate::extension::EnterTime::elapsed), "start a backend request");
x_forwarded_for(&mut req)?;
if req.uri().scheme_str() == Some(FILE_SCHEMA) {
return Ok(static_file_service::static_file_service(req, Path::new("./")).await);
}
let mut client = get_client();
let mut response = if req.headers().get(UPGRADE).is_some_and(|upgrade| upgrade.as_bytes().eq_ignore_ascii_case(b"websocket")) {
// dump request
let (part, body) = req.into_parts();
let body = body.dump().await?;
let req = Request::from_parts(part, body);

// forward request
let resp = client.request(req.clone()).await;

// dump response
let (part, body) = resp.into_parts();
let body = body.dump().await?;
let resp = Response::from_parts(part, body);

let req_for_upgrade = req.clone();
let resp_for_upgrade = resp.clone();

// create forward task
tokio::task::spawn(async move {
// update both side
let (s, c) = futures_util::join!(hyper::upgrade::on(req_for_upgrade), hyper::upgrade::on(resp_for_upgrade));
let upgrade_as_server = s?;
let upgrade_as_client = c?;
// start a websocket forward
ws_client_service::tcp_transfer(upgrade_as_server, upgrade_as_client).await?;
<Result<(), BoxError>>::Ok(())
});
tracing::trace!(elapsed = ?resp.extensions().get::<crate::extension::EnterTime>().map(crate::extension::EnterTime::elapsed), "finish backend websocket forward");
// return response to client
resp
} else {
let resp = client.request(req).await;
tracing::trace!(elapsed = ?resp.extensions().get::<crate::extension::EnterTime>().map(crate::extension::EnterTime::elapsed), "finish backend request");
resp
};
response.extensions_mut().insert(unsafe { crate::extension::FromBackend::new() });
Ok(response)
}

#[instrument]
pub async fn http_backend_service(req: Request<SgBody>) -> Result<Response<SgBody>, Infallible> {
match http_backend_service_inner(req).await {
Ok(resp) => Ok(resp),
Err(err) => Ok(Response::with_code_message(StatusCode::BAD_GATEWAY, format!("[Sg.Client] Client error: {err}"))),
}
}

pub fn get_http_backend_service() -> ArcHyperService {
ArcHyperService::new(hyper::service::service_fn(http_backend_service))
}

pub fn get_echo_service() -> ArcHyperService {
ArcHyperService::new(hyper::service::service_fn(echo::echo))
}
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -73,19 +73,19 @@ fn get_rustls_config_dangerous() -> rustls::ClientConfig {
config
}

pub fn get_client() -> SgHttpClient {
pub fn get_client() -> HttpClient {
ClientRepo::global().get_default()
}

pub struct ClientRepo {
default: SgHttpClient,
repo: Mutex<HashMap<String, SgHttpClient>>,
default: HttpClient,
repo: Mutex<HashMap<String, HttpClient>>,
}

impl Default for ClientRepo {
fn default() -> Self {
let config = get_rustls_config_dangerous();
let default = SgHttpClient::new(config);
let default = HttpClient::new(config);
Self {
default,
repo: Default::default(),
Expand All @@ -95,19 +95,19 @@ impl Default for ClientRepo {

static mut GLOBAL: OnceLock<ClientRepo> = OnceLock::new();
impl ClientRepo {
pub fn get(&self, code: &str) -> Option<SgHttpClient> {
pub fn get(&self, code: &str) -> Option<HttpClient> {
self.repo.lock().expect("failed to lock client repo").get(code).cloned()
}
pub fn get_or_default(&self, code: &str) -> SgHttpClient {
pub fn get_or_default(&self, code: &str) -> HttpClient {
self.get(code).unwrap_or_else(|| self.default.clone())
}
pub fn get_default(&self) -> SgHttpClient {
pub fn get_default(&self) -> HttpClient {
self.default.clone()
}
pub fn register(&self, code: &str, client: SgHttpClient) {
pub fn register(&self, code: &str, client: HttpClient) {
self.repo.lock().expect("failed to lock client repo").insert(code.to_string(), client);
}
pub fn set_default(&mut self, client: SgHttpClient) {
pub fn set_default(&mut self, client: HttpClient) {
self.default = client;
}
pub fn global() -> &'static Self {
Expand All @@ -116,7 +116,7 @@ impl ClientRepo {

/// # Safety
/// This function is not thread safe, it should be called before any other thread is spawned.
pub unsafe fn set_global_default(client: SgHttpClient) {
pub unsafe fn set_global_default(client: HttpClient) {
GLOBAL.get_or_init(Default::default);
GLOBAL.get_mut().expect("global not set").set_default(client);
}
Expand All @@ -127,19 +127,19 @@ pub struct SgHttpClientConfig {
}

#[derive(Debug, Clone)]
pub struct SgHttpClient {
pub struct HttpClient {
inner: Client<HttpsConnector<HttpConnector>, SgBody>,
}

impl Default for SgHttpClient {
impl Default for HttpClient {
fn default() -> Self {
Self::new(rustls::ClientConfig::builder().with_native_roots().expect("failed to init rustls config").with_no_client_auth())
}
}

impl SgHttpClient {
impl HttpClient {
pub fn new(tls_config: rustls::ClientConfig) -> Self {
SgHttpClient {
HttpClient {
inner: Client::builder(TokioExecutor::new()).build(HttpsConnectorBuilder::new().with_tls_config(tls_config).https_or_http().enable_http1().enable_http2().build()),
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use hyper::{
HeaderMap, Response, StatusCode,
};
use tokio::io::AsyncReadExt;
use tracing::{instrument, trace};

use crate::{extension::Reflect, SgBody, SgRequest, SgResponse};

Expand Down Expand Up @@ -38,14 +39,23 @@ pub fn cache_policy(metadata: &Metadata) -> bool {
size < (1 << 20)
}

///
#[instrument()]
pub async fn static_file_service(mut request: SgRequest, dir: &Path) -> SgResponse {
// request.headers().get()
let mut response = Response::builder().body(SgBody::empty()).expect("failed to build response");
if let Some(reflect) = request.extensions_mut().remove::<Reflect>() {
*response.extensions_mut() = reflect.into_inner();
}
let path = dir.join(request.uri().path()).canonicalize().unwrap_or(dir.to_path_buf());
let Ok(dir) = dir.canonicalize() else {
*response.status_mut() = StatusCode::FORBIDDEN;
return response;
};

let Ok(path) = dir.join(request.uri().path().trim_start_matches('/')).canonicalize() else {
*response.status_mut() = StatusCode::FORBIDDEN;
return response;
};

trace!("static file path: {:?}", path);
if !path.starts_with(dir) {
*response.status_mut() = StatusCode::FORBIDDEN;
return response;
Expand Down
2 changes: 1 addition & 1 deletion crates/kernel/src/extension/matched.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{helper_layers::route::Router, layers::http_route::match_request::HttpRouteMatch};
use crate::{helper_layers::route::Router, service::http_route::match_request::HttpRouteMatch};
use std::{ops::Deref, sync::Arc};

#[derive(Debug, Clone)]
Expand Down
4 changes: 0 additions & 4 deletions crates/kernel/src/header.rs

This file was deleted.

6 changes: 3 additions & 3 deletions crates/kernel/src/helper_layers/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,14 +165,14 @@ mod test {
resp
}
}
use crate::SgBoxLayer;
use crate::BoxLayer;

use super::*;
#[test]
fn test_fn_layer() {
let status_message = Arc::new(<HashMap<StatusCode, String>>::default());
let boxed_layer = SgBoxLayer::new(FnLayer::new(MyPlugin::default()));
let boxed_layer2 = SgBoxLayer::new(FnLayer::new_closure(move |req, inner| {
let boxed_layer = BoxLayer::new(FnLayer::new(MyPlugin::default()));
let boxed_layer2 = BoxLayer::new(FnLayer::new_closure(move |req, inner| {
let host = req.headers().get("host");
if let Some(Ok(host)) = host.map(HeaderValue::to_str) {
println!("{host}");
Expand Down
10 changes: 7 additions & 3 deletions crates/kernel/src/helper_layers/random_pick.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,13 @@ where
type Error = S::Error;
type Future = S::Future;

#[allow(clippy::indexing_slicing)]
fn call(&self, req: R) -> Self::Future {
let index = self.picker.sample(&mut rand::thread_rng());
#[allow(clippy::indexing_slicing)]
self.services[index].call(req)
if self.services.len() == 1 {
self.services[0].call(req)
} else {
let index = self.picker.sample(&mut rand::thread_rng());
self.services[index].call(req)
}
}
}
6 changes: 3 additions & 3 deletions crates/kernel/src/helper_layers/route.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pub trait Router: Clone {
}

#[derive(Debug, Clone)]
pub struct Route<S, R, F>
pub struct RouterService<S, R, F>
where
R: Router,
{
Expand All @@ -21,7 +21,7 @@ where
router: R,
}

impl<S, R, F> Route<S, R, F>
impl<S, R, F> RouterService<S, R, F>
where
R: Router,
S: Index<R::Index>,
Expand All @@ -31,7 +31,7 @@ where
}
}

impl<S, R, F> hyper::service::Service<Request<SgBody>> for Route<S, R, F>
impl<S, R, F> hyper::service::Service<Request<SgBody>> for RouterService<S, R, F>
where
R: Router + Send + Sync + 'static,
R::Index: Send + Sync + 'static + Clone,
Expand Down
Loading

0 comments on commit e7e4786

Please sign in to comment.