Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change implementation of incoming-request.authority #2684

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 25 additions & 26 deletions crates/trigger-http/src/handler.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use std::{net::SocketAddr, str, str::FromStr};

use crate::{Body, ChainedRequestHandler, HttpExecutor, HttpInstance, HttpTrigger, Store};
use crate::{
Body, ChainedRequestHandler, HttpExecutor, HttpInstance, HttpTrigger, SelfRequestOrigin, Store,
};
use anyhow::{anyhow, Context, Result};
use futures::TryFutureExt;
use http::{HeaderName, HeaderValue};
Expand Down Expand Up @@ -33,6 +35,7 @@ impl HttpExecutor for HttpHandlerExecutor {
route_match: &RouteMatch,
req: Request<Body>,
client_addr: SocketAddr,
self_origin: SelfRequestOrigin,
) -> Result<Response<Body>> {
let component_id = route_match.component_id();

Expand All @@ -46,7 +49,7 @@ impl HttpExecutor for HttpHandlerExecutor {
unreachable!()
};

set_http_origin_from_request(&mut store, engine.clone(), self, &req);
set_http_origin(&mut store, engine.clone(), self, self_origin);

// set the client tls options for the current component_id.
// The OutboundWasiHttpHandler in this file is only used
Expand Down Expand Up @@ -390,36 +393,32 @@ impl HandlerType {
}
}

fn set_http_origin_from_request(
fn set_http_origin(
store: &mut Store,
engine: Arc<TriggerAppEngine<HttpTrigger>>,
handler: &HttpHandlerExecutor,
req: &Request<Body>,
self_origin: SelfRequestOrigin,
) {
if let Some(authority) = req.uri().authority() {
if let Some(scheme) = req.uri().scheme_str() {
let origin = format!("{}://{}", scheme, authority);
if let Some(outbound_http_handle) = engine
.engine
.find_host_component_handle::<Arc<OutboundHttpComponent>>()
{
let outbound_http_data = store
.host_components_data()
.get_or_insert(outbound_http_handle);

outbound_http_data.origin.clone_from(&origin);
store.as_mut().data_mut().as_mut().allowed_hosts =
outbound_http_data.allowed_hosts.clone();
}
let origin = format!("{}://{}", self_origin.scheme, self_origin.authority);
if let Some(outbound_http_handle) = engine
.engine
.find_host_component_handle::<Arc<OutboundHttpComponent>>()
{
let outbound_http_data = store
.host_components_data()
.get_or_insert(outbound_http_handle);

let chained_request_handler = ChainedRequestHandler {
engine: engine.clone(),
executor: handler.clone(),
};
store.as_mut().data_mut().as_mut().origin = Some(origin);
store.as_mut().data_mut().as_mut().chained_handler = Some(chained_request_handler);
}
outbound_http_data.origin.clone_from(&origin);
store.as_mut().data_mut().as_mut().allowed_hosts = outbound_http_data.allowed_hosts.clone();
}

let chained_request_handler = ChainedRequestHandler {
engine: engine.clone(),
executor: handler.clone(),
self_origin,
};
store.as_mut().data_mut().as_mut().origin = Some(origin);
store.as_mut().data_mut().as_mut().chained_handler = Some(chained_request_handler);
}

fn contextualise_err(e: anyhow::Error) -> anyhow::Error {
Expand Down
63 changes: 53 additions & 10 deletions crates/trigger-http/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ impl HttpTrigger {
server_addr: SocketAddr,
client_addr: SocketAddr,
) -> Result<Response<Body>> {
set_req_uri(&mut req, scheme, server_addr)?;
set_req_uri(&mut req, scheme.clone())?;
strip_forbidden_headers(&mut req);

spin_telemetry::extract_trace_context(&req);
Expand Down Expand Up @@ -278,6 +278,12 @@ impl HttpTrigger {
let trigger = self.component_trigger_configs.get(component_id).unwrap();

let executor = trigger.executor.as_ref().unwrap_or(&HttpExecutorType::Http);
// Set the definition of outbound requests to `self` to be equal to
// the incoming request's scheme and the bound listening address.
let self_origin = SelfRequestOrigin {
scheme,
authority: server_addr.to_string(),
};

let res = match executor {
HttpExecutorType::Http => {
Expand All @@ -288,6 +294,7 @@ impl HttpTrigger {
&route_match,
req,
client_addr,
self_origin,
)
.await
}
Expand All @@ -302,6 +309,7 @@ impl HttpTrigger {
&route_match,
req,
client_addr,
self_origin,
)
.await
}
Expand Down Expand Up @@ -370,15 +378,20 @@ impl HttpTrigger {
stream: S,
server_addr: SocketAddr,
client_addr: SocketAddr,
scheme: Scheme,
) {
task::spawn(async move {
if let Err(e) = http1::Builder::new()
.keep_alive(true)
.serve_connection(
TokioIo::new(stream),
service_fn(move |request| {
self.clone()
.instrumented_service_fn(server_addr, client_addr, request)
self.clone().instrumented_service_fn(
server_addr,
client_addr,
scheme.clone(),
request,
)
}),
)
.await
Expand All @@ -392,6 +405,7 @@ impl HttpTrigger {
self: Arc<Self>,
server_addr: SocketAddr,
client_addr: SocketAddr,
scheme: Scheme,
request: Request<Incoming>,
) -> Result<Response<HyperOutgoingBody>> {
let span = http_span!(request, client_addr);
Expand All @@ -403,7 +417,7 @@ impl HttpTrigger {
body.map_err(wasmtime_wasi_http::hyper_response_error)
.boxed()
}),
Scheme::HTTP,
scheme,
server_addr,
client_addr,
)
Expand All @@ -419,7 +433,7 @@ impl HttpTrigger {
loop {
let (stream, client_addr) = listener.accept().await?;
self.clone()
.serve_connection(stream, listen_addr, client_addr);
.serve_connection(stream, listen_addr, client_addr, Scheme::HTTP);
}
}

Expand All @@ -435,7 +449,10 @@ impl HttpTrigger {
loop {
let (stream, addr) = listener.accept().await?;
match acceptor.accept(stream).await {
Ok(stream) => self.clone().serve_connection(stream, listen_addr, addr),
Ok(stream) => {
self.clone()
.serve_connection(stream, listen_addr, addr, Scheme::HTTPS)
}
Err(err) => tracing::error!(?err, "Failed to start TLS session"),
}
}
Expand Down Expand Up @@ -475,11 +492,21 @@ fn parse_listen_addr(addr: &str) -> anyhow::Result<SocketAddr> {

/// The incoming request's scheme and authority
///
/// The incoming request's URI is relative to the server, so we need to set the scheme and authority
fn set_req_uri(req: &mut Request<Body>, scheme: Scheme, addr: SocketAddr) -> Result<()> {
/// The incoming request's URI is relative to the server, so we need to set the scheme and authority.
/// The `Host` header is used to set the authority. This function will error if no `Host` header is
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Host header is used to set the authority.

Should this also check the req.uri().authority()? I think it might be most correct to use either one, verifying that they are identical if both are present.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm pretty sure that's right, based on hyperium/hyper#1612

/// present or if it is not parseable as an `Authority`.
fn set_req_uri(req: &mut Request<Body>, scheme: Scheme) -> Result<()> {
let uri = req.uri().clone();
let mut parts = uri.into_parts();
let authority = format!("{}:{}", addr.ip(), addr.port()).parse().unwrap();
let headers = req.headers();
let host_header = headers
.get(HOST)
.context("missing Host header")?
.to_str()
.context("Host header is not valid UTF-8")?;
let authority = host_header
.parse()
.context("Host header contains an invalid authority")?;
parts.scheme = Some(scheme);
parts.authority = Some(authority);
*req.uri_mut() = Uri::from_parts(parts).unwrap();
Expand Down Expand Up @@ -573,13 +600,22 @@ pub(crate) trait HttpExecutor: Clone + Send + Sync + 'static {
route_match: &RouteMatch,
req: Request<Body>,
client_addr: SocketAddr,
self_origin: SelfRequestOrigin,
) -> Result<Response<Body>>;
}

/// The origin of the `self` host for outbound requests.
#[derive(Clone)]
pub struct SelfRequestOrigin {
scheme: Scheme,
authority: String,
}

#[derive(Clone)]
struct ChainedRequestHandler {
engine: Arc<TriggerAppEngine<HttpTrigger>>,
executor: HttpHandlerExecutor,
self_origin: SelfRequestOrigin,
}

#[derive(Default)]
Expand Down Expand Up @@ -622,7 +658,14 @@ impl HttpRuntimeData {

let resp_fut = async move {
match handler
.execute(engine.clone(), base, &route_match, request, client_addr)
.execute(
engine.clone(),
base,
&route_match,
request,
client_addr,
chained_handler.self_origin,
)
.await
{
Ok(resp) => Ok(Ok(IncomingResponse {
Expand Down
3 changes: 2 additions & 1 deletion crates/trigger-http/src/wagi.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{io::Cursor, net::SocketAddr, sync::Arc};

use crate::HttpInstance;
use crate::{HttpInstance, SelfRequestOrigin};
use anyhow::{anyhow, ensure, Context, Result};
use async_trait::async_trait;
use http_body_util::BodyExt;
Expand Down Expand Up @@ -28,6 +28,7 @@ impl HttpExecutor for WagiHttpExecutor {
route_match: &RouteMatch,
req: Request<Body>,
client_addr: SocketAddr,
_self_origin: SelfRequestOrigin,
) -> Result<Response<Body>> {
let component = route_match.component_id();

Expand Down
2 changes: 1 addition & 1 deletion tests/integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ Caused by:
}

#[test]
fn outbound_http_works() -> anyhow::Result<()> {
fn outbound_http_to_same_app_works() -> anyhow::Result<()> {
run_test(
"outbound-http-to-same-app",
SpinConfig {
Expand Down
2 changes: 1 addition & 1 deletion tests/runtime-tests/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ impl RuntimeTest<InProcessSpin> {
pub fn run(&mut self) {
self.run_test(|env| {
let runtime = env.runtime_mut();
let response = runtime.make_http_request(Request::new(Method::Get, "/"))?;
let response = runtime.make_http_request(Request::full(Method::Get, "/", &[("Host", "example.com")], None))?;
if response.status() == 200 {
return Ok(());
}
Expand Down
14 changes: 9 additions & 5 deletions tests/testing-framework/src/runtimes/in_process_spin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,16 @@ impl InProcessSpin {
pub fn make_http_request(&self, req: Request<'_, &[u8]>) -> anyhow::Result<Response> {
tokio::runtime::Runtime::new()?.block_on(async {
let method: reqwest::Method = req.method.into();
let req = http::request::Request::builder()
let mut builder = http::request::Request::builder()
.method(method)
.uri(req.path)
// TODO(rylev): convert headers and body as well
.body(spin_http::body::empty())
.unwrap();
.uri(req.path);

for (key, value) in req.headers {
builder = builder.header(*key, *value);
}

// TODO(rylev): convert body as well
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A TODO IN YOUR CODE?!? 🙂

let req = builder.body(spin_http::body::empty()).unwrap();
let response = self
.trigger
.handle(
Expand Down
Loading