From 7f078d7b3c0da7ce4c53ab0a7ad918e92dc9a590 Mon Sep 17 00:00:00 2001 From: Alex Konradi Date: Tue, 5 Dec 2023 14:23:45 -0500 Subject: [PATCH] Accept body in response without content-length --- rust/net/src/infra/errors.rs | 2 + rust/net/src/infra/http.rs | 98 ++++++++++++++++++++++++++++++------ 2 files changed, 85 insertions(+), 15 deletions(-) diff --git a/rust/net/src/infra/errors.rs b/rust/net/src/infra/errors.rs index 7c1a4f93a1..b9ef3d0b87 100644 --- a/rust/net/src/infra/errors.rs +++ b/rust/net/src/infra/errors.rs @@ -56,6 +56,8 @@ pub enum NetError { ServerRequestMissingId, /// Failed while sending a request from the server to the incoming messages channel FailedToPassMessageToIncomingChannel, + /// An HTTP stream was interrupted while receiving data. + HttpInterruptedDuringReceive, } impl LogSafeDisplay for NetError {} diff --git a/rust/net/src/infra/http.rs b/rust/net/src/infra/http.rs index 4067a5a283..21ec0d1f3d 100644 --- a/rust/net/src/infra/http.rs +++ b/rust/net/src/infra/http.rs @@ -84,21 +84,28 @@ impl AggregatingHttpClient for AggregatingHttp2Client { let (parts, body) = res.into_parts(); - let content = match parts.headers.get(hyper::header::CONTENT_LENGTH) { - Some(content_length_str) => { - let content_length = content_length_str - .to_str() - .map_err(|_| NetError::ContentLengthHeaderInvalid)? - .parse::() - .map_err(|_| NetError::ContentLengthHeaderInvalid)?; - Limited::new(body, content_length) - .collect() - .await - .map_err(|_| NetError::ContentLengthHeaderDoesntMatchDataSize)? - .to_bytes() - } - None => Bytes::new(), - }; + let content_length = parts + .headers + .get(hyper::header::CONTENT_LENGTH) + .map(|c| { + c.to_str() + .ok() + .and_then(|s| s.parse().ok()) + .ok_or(NetError::ContentLengthHeaderInvalid) + }) + .transpose()?; + + let content = match content_length { + Some(content_length) => Limited::new(body, content_length) + .collect() + .await + .map_err(|_| NetError::ContentLengthHeaderDoesntMatchDataSize)?, + _ => body + .collect() + .await + .map_err(|_| NetError::HttpInterruptedDuringReceive)?, + } + .to_bytes(); Ok((parts, content)) } @@ -128,3 +135,64 @@ pub(crate) async fn http2_channel( connection, }) } + +#[cfg(test)] +mod test { + use std::collections::HashMap; + + use lazy_static::lazy_static; + use warp::Filter as _; + + use crate::infra::test::shared::InMemoryWarpConnector; + + use super::*; + + const FAKE_PORT: u16 = 1212; + lazy_static! { + static ref FAKE_CONNECTION_PARAMS: ConnectionParams = ConnectionParams { + sni: "sni".into(), + host: "host".into(), + port: FAKE_PORT, + http_request_decorator: Default::default(), + certs: crate::infra::certs::RootCertificates::Native, + dns_resolver: crate::infra::dns::DnsResolver::Static + }; + } + + #[tokio::test] + async fn aggregating_client_accepts_response_without_content_length() { + // HTTP servers are not required to send a content-length header. + const FAKE_BODY: &str = "body"; + const FAKE_PATH_AND_QUERY: &str = "/path?query=true"; + + let h2_server = warp::get().and(warp::path("path")).and(warp::query()).then( + |query: HashMap| async move { + assert_eq!(query.get("query").map(String::as_str), Some("true")); + warp::reply::html(FAKE_BODY) + }, + ); + + let transport_connector = InMemoryWarpConnector::new(h2_server); + + let Http2Channel { + mut aggregating_client, + connection, + } = http2_channel(&transport_connector, &FAKE_CONNECTION_PARAMS) + .await + .expect("can connect"); + + let _connection_task = tokio::spawn(connection); + + let response = aggregating_client + .send_request_aggregate_response( + PathAndQuery::from_static(FAKE_PATH_AND_QUERY), + Builder::new(), + Bytes::new(), + ) + .await + .expect("gets response"); + + let (_parts, content) = response; + assert_eq!(content, FAKE_BODY); + } +}