Skip to content

Commit

Permalink
Accept body in response without content-length
Browse files Browse the repository at this point in the history
  • Loading branch information
akonradi-signal authored Dec 5, 2023
1 parent 1f2d761 commit 7f078d7
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 15 deletions.
2 changes: 2 additions & 0 deletions rust/net/src/infra/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ pub enum NetError {
ServerRequestMissingId,
/// Failed while sending a request from the server to the incoming messages channel
FailedToPassMessageToIncomingChannel,
/// An HTTP stream was interrupted while receiving data.
HttpInterruptedDuringReceive,
}

impl LogSafeDisplay for NetError {}
Expand Down
98 changes: 83 additions & 15 deletions rust/net/src/infra/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,21 +84,28 @@ impl AggregatingHttpClient for AggregatingHttp2Client {

let (parts, body) = res.into_parts();

let content = match parts.headers.get(hyper::header::CONTENT_LENGTH) {
Some(content_length_str) => {
let content_length = content_length_str
.to_str()
.map_err(|_| NetError::ContentLengthHeaderInvalid)?
.parse::<usize>()
.map_err(|_| NetError::ContentLengthHeaderInvalid)?;
Limited::new(body, content_length)
.collect()
.await
.map_err(|_| NetError::ContentLengthHeaderDoesntMatchDataSize)?
.to_bytes()
}
None => Bytes::new(),
};
let content_length = parts
.headers
.get(hyper::header::CONTENT_LENGTH)
.map(|c| {
c.to_str()
.ok()
.and_then(|s| s.parse().ok())
.ok_or(NetError::ContentLengthHeaderInvalid)
})
.transpose()?;

let content = match content_length {
Some(content_length) => Limited::new(body, content_length)
.collect()
.await
.map_err(|_| NetError::ContentLengthHeaderDoesntMatchDataSize)?,
_ => body
.collect()
.await
.map_err(|_| NetError::HttpInterruptedDuringReceive)?,
}
.to_bytes();

Ok((parts, content))
}
Expand Down Expand Up @@ -128,3 +135,64 @@ pub(crate) async fn http2_channel<C: TransportConnector>(
connection,
})
}

#[cfg(test)]
mod test {
use std::collections::HashMap;

use lazy_static::lazy_static;
use warp::Filter as _;

use crate::infra::test::shared::InMemoryWarpConnector;

use super::*;

const FAKE_PORT: u16 = 1212;
lazy_static! {
static ref FAKE_CONNECTION_PARAMS: ConnectionParams = ConnectionParams {
sni: "sni".into(),
host: "host".into(),
port: FAKE_PORT,
http_request_decorator: Default::default(),
certs: crate::infra::certs::RootCertificates::Native,
dns_resolver: crate::infra::dns::DnsResolver::Static
};
}

#[tokio::test]
async fn aggregating_client_accepts_response_without_content_length() {
// HTTP servers are not required to send a content-length header.
const FAKE_BODY: &str = "body";
const FAKE_PATH_AND_QUERY: &str = "/path?query=true";

let h2_server = warp::get().and(warp::path("path")).and(warp::query()).then(
|query: HashMap<String, String>| async move {
assert_eq!(query.get("query").map(String::as_str), Some("true"));
warp::reply::html(FAKE_BODY)
},
);

let transport_connector = InMemoryWarpConnector::new(h2_server);

let Http2Channel {
mut aggregating_client,
connection,
} = http2_channel(&transport_connector, &FAKE_CONNECTION_PARAMS)
.await
.expect("can connect");

let _connection_task = tokio::spawn(connection);

let response = aggregating_client
.send_request_aggregate_response(
PathAndQuery::from_static(FAKE_PATH_AND_QUERY),
Builder::new(),
Bytes::new(),
)
.await
.expect("gets response");

let (_parts, content) = response;
assert_eq!(content, FAKE_BODY);
}
}

0 comments on commit 7f078d7

Please sign in to comment.