diff --git a/Cargo.toml b/Cargo.toml index 9b64b2d..2a10cca 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,7 @@ license = "Apache-2.0" name = "axum-streams" readme = "README.md" include = ["Cargo.toml", "src/**/*.rs", "README.md", "LICENSE"] -version = "0.16.0" +version = "0.17.0" [badges] maintenance = { status = "actively-developed" } @@ -55,6 +55,8 @@ tower-service = "0.3" tokio = { version = "1", features = ["full"] } prost = { version= "0.12", features = ["prost-derive"] } arrow = { version = "52", features = ["ipc"] } +tracing = "0.1" +tracing-subscriber = { version = "0.3"} [[example]] name = "json-example" @@ -86,6 +88,11 @@ name = "json-with-buffering" path = "examples/json-with-buffering.rs" required-features = ["json"] +[[example]] +name = "json-with-errors-example" +path = "examples/json-with-errors-example.rs" +required-features = ["json"] + [[example]] name = "arrow-example" path = "examples/arrow-example.rs" diff --git a/README.md b/README.md index 68e43b4..efffa95 100644 --- a/README.md +++ b/README.md @@ -89,6 +89,37 @@ You can change this is using `StreamAsOptions`: .json_array(source_test_stream()) ``` +## Error handling +The library provides a way to propagate errors in the stream: + +```rust +struct MyError { + message: String, +} + +impl Into for MyError { + fn into(self) -> axum::Error { + axum::Error::new(self.message) + } +} + +fn my_source_stream() -> impl Stream> { + // Simulating a stream with a plain vector and throttling to show how it works + stream::iter(vec![ + Ok(MyTestStructure { + some_test_field: "test1".to_string() + }); 1000 + ]) +} + +async fn test_json_array_stream() -> impl IntoResponse { + // Use _with_errors functions or directly `StreamBodyAs::with_options` + // to produce a stream with errors + StreamBodyAs::json_array_with_errors(source_test_stream()) +} + +``` + ## JSON array inside another object Sometimes you need to include your array inside some object, e.g.: ```json diff --git a/examples/json-with-errors-example.rs b/examples/json-with-errors-example.rs new file mode 100644 index 0000000..c1b40d5 --- /dev/null +++ b/examples/json-with-errors-example.rs @@ -0,0 +1,72 @@ +use axum::response::IntoResponse; +use axum::routing::*; +use axum::Router; +use futures::{stream, Stream, StreamExt}; + +use serde::{Deserialize, Serialize}; +use tokio::net::TcpListener; + +use axum_streams::*; + +#[derive(Debug, Clone, Deserialize, Serialize)] +struct MyTestStructure { + some_test_field: String, +} + +struct MyError { + message: String, +} + +impl Into for MyError { + fn into(self) -> axum::Error { + axum::Error::new(self.message) + } +} + +fn source_test_stream() -> impl Stream> { + // Simulating a stream with a plain vector and throttling to show how it works + tokio_stream::StreamExt::throttle( + stream::iter(vec![ + MyTestStructure { + some_test_field: "test1".to_string() + }; + 10000 + ]) + .enumerate() + .map(|(idx, item)| { + if idx != 0 && idx % 10 == 0 { + Err(MyError { + message: format!("Error at index {}", idx), + }) + } else { + Ok(item) + } + }), + std::time::Duration::from_millis(500), + ) +} + +async fn test_json_array_stream() -> impl IntoResponse { + StreamBodyAs::json_array_with_errors(source_test_stream()) +} + +async fn test_json_nl_stream() -> impl IntoResponse { + StreamBodyAs::json_nl_with_errors(source_test_stream()) +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + tracing_subscriber::fmt().with_target(false).init(); + + // build our application with a route + let app = Router::new() + // `GET /` goes to `root` + .route("/json-array-stream", get(test_json_array_stream)) + .route("/json-nl-stream", get(test_json_nl_stream)); + + let listener = TcpListener::bind("127.0.0.1:8080").await?; + + axum::serve(listener, app).await?; + + Ok(()) +} diff --git a/src/arrow_format.rs b/src/arrow_format.rs index 7395f2d..99431b1 100644 --- a/src/arrow_format.rs +++ b/src/arrow_format.rs @@ -33,7 +33,7 @@ impl ArrowRecordBatchIpcStreamFormat { impl StreamingFormat for ArrowRecordBatchIpcStreamFormat { fn to_bytes_stream<'a, 'b>( &'a self, - stream: BoxStream<'b, RecordBatch>, + stream: BoxStream<'b, Result>, _: &'a StreamBodyAsOptions, ) -> BoxStream<'b, Result> { fn write_batch( @@ -81,8 +81,9 @@ impl StreamingFormat for ArrowRecordBatchIpcStreamFormat { let batch_stream = Box::pin({ stream.scan( (ipc_data_gen, dictionary_tracker, 0), - move |(ipc_data_gen, dictionary_tracker, idx), batch| { - futures::future::ready({ + move |(ipc_data_gen, dictionary_tracker, idx), batch_res| match batch_res { + Err(e) => futures::future::ready(Some(Err(e))), + Ok(batch) => futures::future::ready({ let prepend_schema = if *idx == 0 { Some(batch_schema.clone()) } else { @@ -98,7 +99,7 @@ impl StreamingFormat for ArrowRecordBatchIpcStreamFormat { ) .map_err(axum::Error::new); Some(bytes) - }) + }), }, ) }); @@ -111,7 +112,7 @@ impl StreamingFormat for ArrowRecordBatchIpcStreamFormat { Box::pin(batch_stream.chain(append_stream)) } - fn http_response_trailers(&self, options: &StreamBodyAsOptions) -> Option { + fn http_response_headers(&self, options: &StreamBodyAsOptions) -> Option { let mut header_map = HeaderMap::new(); header_map.insert( http::header::CONTENT_TYPE, @@ -127,6 +128,17 @@ impl<'a> crate::StreamBodyAs<'a> { pub fn arrow_ipc(schema: SchemaRef, stream: S) -> Self where S: Stream + 'a + Send, + { + Self::new( + ArrowRecordBatchIpcStreamFormat::new(schema), + stream.map(Ok::), + ) + } + + pub fn arrow_ipc_with_errors(schema: SchemaRef, stream: S) -> Self + where + S: Stream> + 'a + Send, + E: Into, { Self::new(ArrowRecordBatchIpcStreamFormat::new(schema), stream) } @@ -134,6 +146,21 @@ impl<'a> crate::StreamBodyAs<'a> { pub fn arrow_ipc_with_options(schema: SchemaRef, stream: S, options: IpcWriteOptions) -> Self where S: Stream + 'a + Send, + { + Self::new( + ArrowRecordBatchIpcStreamFormat::with_options(schema, options), + stream.map(Ok::), + ) + } + + pub fn arrow_ipc_with_options_errors( + schema: SchemaRef, + stream: S, + options: IpcWriteOptions, + ) -> Self + where + S: Stream> + 'a + Send, + E: Into, { Self::new( ArrowRecordBatchIpcStreamFormat::with_options(schema, options), @@ -146,6 +173,18 @@ impl StreamBodyAsOptions { pub fn arrow_ipc<'a, S>(self, schema: SchemaRef, stream: S) -> StreamBodyAs<'a> where S: Stream + 'a + Send, + { + StreamBodyAs::with_options( + ArrowRecordBatchIpcStreamFormat::new(schema), + stream.map(Ok::), + self, + ) + } + + pub fn arrow_ipc_with_errors<'a, S, E>(self, schema: SchemaRef, stream: S) -> StreamBodyAs<'a> + where + S: Stream> + 'a + Send, + E: Into, { StreamBodyAs::with_options(ArrowRecordBatchIpcStreamFormat::new(schema), stream, self) } @@ -158,6 +197,23 @@ impl StreamBodyAsOptions { ) -> StreamBodyAs<'a> where S: Stream + 'a + Send, + { + StreamBodyAs::with_options( + ArrowRecordBatchIpcStreamFormat::with_options(schema, options), + stream.map(Ok::), + self, + ) + } + + pub fn arrow_ipc_with_options_errors<'a, S, E>( + self, + schema: SchemaRef, + stream: S, + options: IpcWriteOptions, + ) -> StreamBodyAs<'a> + where + S: Stream> + 'a + Send, + E: Into, { StreamBodyAs::with_options( ArrowRecordBatchIpcStreamFormat::with_options(schema, options), @@ -214,7 +270,7 @@ mod tests { get(|| async move { StreamBodyAs::new( ArrowRecordBatchIpcStreamFormat::new(app_schema.clone()), - test_stream, + test_stream.map(Ok::<_, axum::Error>), ) }), ); diff --git a/src/csv_format.rs b/src/csv_format.rs index 1161758..4e4c7ca 100644 --- a/src/csv_format.rs +++ b/src/csv_format.rs @@ -97,7 +97,7 @@ where { fn to_bytes_stream<'a, 'b>( &'a self, - stream: BoxStream<'b, T>, + stream: BoxStream<'b, Result>, _: &'a StreamBodyAsOptions, ) -> BoxStream<'b, Result> { let stream_with_header = self.has_headers; @@ -110,29 +110,34 @@ where let terminator = self.terminator; Box::pin({ - stream.enumerate().map(move |(index, obj)| { - let mut writer = csv::WriterBuilder::new() - .has_headers(index == 0 && stream_with_header) - .delimiter(stream_delimiter) - .flexible(stream_flexible) - .quote_style(stream_quote_style) - .quote(stream_quote) - .double_quote(stream_double_quote) - .escape(stream_escape) - .terminator(terminator) - .from_writer(vec![]); - - writer.serialize(obj).map_err(axum::Error::new)?; - writer.flush().map_err(axum::Error::new)?; - writer - .into_inner() - .map_err(axum::Error::new) - .map(axum::body::Bytes::from) - }) + stream + .enumerate() + .map(move |(index, obj_res)| match obj_res { + Err(e) => Err(e), + Ok(obj) => { + let mut writer = csv::WriterBuilder::new() + .has_headers(index == 0 && stream_with_header) + .delimiter(stream_delimiter) + .flexible(stream_flexible) + .quote_style(stream_quote_style) + .quote(stream_quote) + .double_quote(stream_double_quote) + .escape(stream_escape) + .terminator(terminator) + .from_writer(vec![]); + + writer.serialize(obj).map_err(axum::Error::new)?; + writer.flush().map_err(axum::Error::new)?; + writer + .into_inner() + .map_err(axum::Error::new) + .map(axum::body::Bytes::from) + } + }) }) } - fn http_response_trailers(&self, options: &StreamBodyAsOptions) -> Option { + fn http_response_headers(&self, options: &StreamBodyAsOptions) -> Option { let mut header_map = HeaderMap::new(); header_map.insert( http::header::CONTENT_TYPE, @@ -150,6 +155,18 @@ impl<'a> StreamBodyAs<'a> { where T: Serialize + Send + Sync + 'static, S: Stream + 'a + Send, + { + Self::new( + CsvStreamFormat::new(false, b','), + stream.map(Ok::), + ) + } + + pub fn csv_with_errors(stream: S) -> Self + where + T: Serialize + Send + Sync + 'static, + S: Stream> + 'a + Send, + E: Into + 'static, { Self::new(CsvStreamFormat::new(false, b','), stream) } @@ -160,6 +177,19 @@ impl StreamBodyAsOptions { where T: Serialize + Send + Sync + 'static, S: Stream + 'a + Send, + { + StreamBodyAs::with_options( + CsvStreamFormat::new(false, b','), + stream.map(Ok::), + self, + ) + } + + pub fn csv_with_errors<'a, S, T, E>(self, stream: S) -> StreamBodyAs<'a> + where + T: Serialize + Send + Sync + 'static, + S: Stream> + 'a + Send, + E: Into + 'static, { StreamBodyAs::with_options(CsvStreamFormat::new(false, b','), stream, self) } @@ -197,7 +227,7 @@ mod tests { get(|| async { StreamBodyAs::new( CsvStreamFormat::new(false, b'.').with_delimiter(b','), - test_stream, + test_stream.map(Ok::<_, axum::Error>), ) }), ); diff --git a/src/json_formats.rs b/src/json_formats.rs index 6e818cf..3146845 100644 --- a/src/json_formats.rs +++ b/src/json_formats.rs @@ -41,27 +41,30 @@ where { fn to_bytes_stream<'a, 'b>( &'a self, - stream: BoxStream<'b, T>, + stream: BoxStream<'b, Result>, _: &'a StreamBodyAsOptions, ) -> BoxStream<'b, Result> { let stream_bytes: BoxStream> = Box::pin({ - stream.enumerate().map(|(index, obj)| { - let mut buf = BytesMut::new().writer(); - - let sep_write_res = if index != 0 { - buf.write_all(JSON_SEP_BYTES).map_err(axum::Error::new) - } else { - Ok(()) - }; - - match sep_write_res { - Ok(_) => { - match serde_json::to_writer(&mut buf, &obj).map_err(axum::Error::new) { - Ok(_) => Ok(buf.into_inner().freeze()), - Err(e) => Err(e), + stream.enumerate().map(|(index, obj_res)| match obj_res { + Err(e) => Err(e), + Ok(obj) => { + let mut buf = BytesMut::new().writer(); + + let sep_write_res = if index != 0 { + buf.write_all(JSON_SEP_BYTES).map_err(axum::Error::new) + } else { + Ok(()) + }; + + match sep_write_res { + Ok(_) => { + match serde_json::to_writer(&mut buf, &obj).map_err(axum::Error::new) { + Ok(_) => Ok(buf.into_inner().freeze()), + Err(e) => Err(e), + } } + Err(e) => Err(e), } - Err(e) => Err(e), } }) }); @@ -116,7 +119,7 @@ where Box::pin(prepend_stream.chain(stream_bytes.chain(append_stream))) } - fn http_response_trailers(&self, options: &StreamBodyAsOptions) -> Option { + fn http_response_headers(&self, options: &StreamBodyAsOptions) -> Option { let mut header_map = HeaderMap::new(); header_map.insert( http::header::CONTENT_TYPE, @@ -143,24 +146,27 @@ where { fn to_bytes_stream<'a, 'b>( &'a self, - stream: BoxStream<'b, T>, + stream: BoxStream<'b, Result>, _: &'a StreamBodyAsOptions, ) -> BoxStream<'b, Result> { Box::pin({ - stream.map(|obj| { - let mut buf = BytesMut::new().writer(); - match serde_json::to_writer(&mut buf, &obj).map_err(axum::Error::new) { - Ok(_) => match buf.write_all(JSON_NL_SEP_BYTES).map_err(axum::Error::new) { - Ok(_) => Ok(buf.into_inner().freeze()), + stream.map(|obj_res| match obj_res { + Err(e) => Err(e), + Ok(obj) => { + let mut buf = BytesMut::new().writer(); + match serde_json::to_writer(&mut buf, &obj).map_err(axum::Error::new) { + Ok(_) => match buf.write_all(JSON_NL_SEP_BYTES).map_err(axum::Error::new) { + Ok(_) => Ok(buf.into_inner().freeze()), + Err(e) => Err(e), + }, Err(e) => Err(e), - }, - Err(e) => Err(e), + } } }) }) } - fn http_response_trailers(&self, _: &StreamBodyAsOptions) -> Option { + fn http_response_headers(&self, _: &StreamBodyAsOptions) -> Option { let mut header_map = HeaderMap::new(); header_map.insert( http::header::CONTENT_TYPE, @@ -182,15 +188,44 @@ impl<'a> crate::StreamBodyAs<'a> { where T: Serialize + Send + Sync + 'static, S: Stream + 'a + Send, + { + Self::new( + JsonArrayStreamFormat::new(), + stream.map(Ok::), + ) + } + + pub fn json_array_with_errors(stream: S) -> Self + where + T: Serialize + Send + Sync + 'static, + S: Stream> + 'a + Send, + E: Into, { Self::new(JsonArrayStreamFormat::new(), stream) } - pub fn json_array_with_envelope(stream: S, envelope: E, array_field: &str) -> Self + pub fn json_array_with_envelope(stream: S, envelope: EN, array_field: &str) -> Self where T: Serialize + Send + Sync + 'static, S: Stream + 'a + Send, - E: Serialize + Send + Sync + 'static, + EN: Serialize + Send + Sync + 'static, + { + Self::new( + JsonArrayStreamFormat::with_envelope(envelope, array_field), + stream.map(Ok::), + ) + } + + pub fn json_array_with_envelope_errors( + stream: S, + envelope: EN, + array_field: &str, + ) -> Self + where + T: Serialize + Send + Sync + 'static, + S: Stream> + 'a + Send, + E: Into, + EN: Serialize + Send + Sync + 'static, { Self::new( JsonArrayStreamFormat::with_envelope(envelope, array_field), @@ -202,6 +237,18 @@ impl<'a> crate::StreamBodyAs<'a> { where T: Serialize + Send + Sync + 'static, S: Stream + 'a + Send, + { + Self::new( + JsonNewLineStreamFormat::new(), + stream.map(Ok::), + ) + } + + pub fn json_nl_with_errors(stream: S) -> Self + where + T: Serialize + Send + Sync + 'static, + S: Stream> + 'a + Send, + E: Into, { Self::new(JsonNewLineStreamFormat::new(), stream) } @@ -212,20 +259,52 @@ impl StreamBodyAsOptions { where T: Serialize + Send + Sync + 'static, S: Stream + 'a + Send, + { + StreamBodyAs::with_options( + JsonArrayStreamFormat::new(), + stream.map(Ok::), + self, + ) + } + + pub fn json_array_with_errors<'a, S, T, E>(self, stream: S) -> StreamBodyAs<'a> + where + T: Serialize + Send + Sync + 'static, + S: Stream> + 'a + Send, + E: Into, { StreamBodyAs::with_options(JsonArrayStreamFormat::new(), stream, self) } - pub fn json_array_with_envelope<'a, S, T, E>( + pub fn json_array_with_envelope<'a, S, T, EN>( self, stream: S, - envelope: E, + envelope: EN, array_field: &str, ) -> StreamBodyAs<'a> where T: Serialize + Send + Sync + 'static, S: Stream + 'a + Send, - E: Serialize + Send + Sync + 'static, + EN: Serialize + Send + Sync + 'static, + { + StreamBodyAs::with_options( + JsonArrayStreamFormat::with_envelope(envelope, array_field), + stream.map(Ok::), + self, + ) + } + + pub fn json_array_with_envelope_errors<'a, S, T, E, EN>( + self, + stream: S, + envelope: EN, + array_field: &str, + ) -> StreamBodyAs<'a> + where + T: Serialize + Send + Sync + 'static, + S: Stream> + 'a + Send, + E: Into, + EN: Serialize + Send + Sync + 'static, { StreamBodyAs::with_options( JsonArrayStreamFormat::with_envelope(envelope, array_field), @@ -238,6 +317,19 @@ impl StreamBodyAsOptions { where T: Serialize + Send + Sync + 'static, S: Stream + 'a + Send, + { + StreamBodyAs::with_options( + JsonNewLineStreamFormat::new(), + stream.map(Ok::), + self, + ) + } + + pub fn json_nl_with_errors<'a, S, T, E>(self, stream: S) -> StreamBodyAs<'a> + where + T: Serialize + Send + Sync + 'static, + S: Stream> + 'a + Send, + E: Into, { StreamBodyAs::with_options(JsonNewLineStreamFormat::new(), stream, self) } @@ -269,7 +361,12 @@ mod tests { let app = Router::new().route( "/", - get(|| async { StreamBodyAs::new(JsonArrayStreamFormat::new(), test_stream) }), + get(|| async { + StreamBodyAs::new( + JsonArrayStreamFormat::new(), + test_stream.map(Ok::<_, axum::Error>), + ) + }), ); let client = TestClient::new(app).await; @@ -306,7 +403,12 @@ mod tests { let app = Router::new().route( "/", - get(|| async { StreamBodyAs::new(JsonNewLineStreamFormat::new(), test_stream) }), + get(|| async { + StreamBodyAs::new( + JsonNewLineStreamFormat::new(), + test_stream.map(Ok::<_, axum::Error>), + ) + }), ); let client = TestClient::new(app).await; @@ -364,7 +466,7 @@ mod tests { get(|| async { StreamBodyAs::new( JsonArrayStreamFormat::with_envelope(test_envelope, "my_array"), - test_stream, + test_stream.map(Ok::<_, axum::Error>), ) }), ); @@ -421,7 +523,7 @@ mod tests { get(|| async { StreamBodyAs::new( JsonArrayStreamFormat::with_envelope(test_envelope, "my_array"), - test_stream, + test_stream.map(Ok::<_, axum::Error>), ) }), ); diff --git a/src/protobuf_format.rs b/src/protobuf_format.rs index 5889f7b..b0764e1 100644 --- a/src/protobuf_format.rs +++ b/src/protobuf_format.rs @@ -20,7 +20,7 @@ where { fn to_bytes_stream<'a, 'b>( &'a self, - stream: BoxStream<'b, T>, + stream: BoxStream<'b, Result>, _: &'a StreamBodyAsOptions, ) -> BoxStream<'b, Result> { fn write_protobuf_record(obj: T) -> Result, axum::Error> @@ -37,14 +37,17 @@ where } Box::pin({ - stream.map(move |obj| { - let write_protobuf_res = write_protobuf_record(obj); - write_protobuf_res.map(axum::body::Bytes::from) + stream.map(move |obj_res| match obj_res { + Err(e) => Err(e), + Ok(obj) => { + let write_protobuf_res = write_protobuf_record(obj); + write_protobuf_res.map(axum::body::Bytes::from) + } }) }) } - fn http_response_trailers(&self, options: &StreamBodyAsOptions) -> Option { + fn http_response_headers(&self, options: &StreamBodyAsOptions) -> Option { let mut header_map = HeaderMap::new(); header_map.insert( http::header::CONTENT_TYPE, @@ -61,6 +64,18 @@ impl<'a> StreamBodyAs<'a> { where T: prost::Message + Send + Sync + 'static, S: Stream + 'a + Send, + { + Self::new( + ProtobufStreamFormat::new(), + stream.map(Ok::), + ) + } + + pub fn protobuf_with_errors(stream: S) -> Self + where + T: prost::Message + Send + Sync + 'static, + S: Stream> + 'a + Send, + E: Into, { Self::new(ProtobufStreamFormat::new(), stream) } @@ -71,6 +86,19 @@ impl StreamBodyAsOptions { where T: prost::Message + Send + Sync + 'static, S: Stream + 'a + Send, + { + StreamBodyAs::with_options( + ProtobufStreamFormat::new(), + stream.map(Ok::), + self, + ) + } + + pub fn protobuf_with_errors<'a, S, T, E>(self, stream: S) -> StreamBodyAs<'a> + where + T: prost::Message + Send + Sync + 'static, + S: Stream> + 'a + Send, + E: Into, { StreamBodyAs::with_options(ProtobufStreamFormat::new(), stream, self) } @@ -107,7 +135,12 @@ mod tests { let app = Router::new().route( "/", - get(|| async { StreamBodyAs::new(ProtobufStreamFormat::new(), test_stream) }), + get(|| async { + StreamBodyAs::new( + ProtobufStreamFormat::new(), + test_stream.map(Ok::<_, axum::Error>), + ) + }), ); let client = TestClient::new(app).await; diff --git a/src/stream_body_as.rs b/src/stream_body_as.rs index 3373aa5..73d404e 100644 --- a/src/stream_body_as.rs +++ b/src/stream_body_as.rs @@ -3,8 +3,8 @@ use axum::body::{Body, HttpBody}; use axum::response::{IntoResponse, Response}; use bytes::BytesMut; use futures::stream::BoxStream; -use futures::Stream; use futures::StreamExt; +use futures::{Stream, TryStreamExt}; use http::{HeaderMap, HeaderValue}; use http_body::Frame; use std::fmt::Formatter; @@ -24,26 +24,28 @@ impl<'a> std::fmt::Debug for StreamBodyAs<'a> { impl<'a> StreamBodyAs<'a> { /// Create a new `StreamBodyWith` providing a stream of your objects in the specified format. - pub fn new(stream_format: FMT, stream: S) -> Self + pub fn new(stream_format: FMT, stream: S) -> Self where FMT: StreamingFormat, - S: Stream + 'a + Send, + S: Stream> + 'a + Send, + E: Into, { Self::with_options(stream_format, stream, StreamBodyAsOptions::new()) } - pub fn with_options( + pub fn with_options( stream_format: FMT, stream: S, options: StreamBodyAsOptions, ) -> Self where FMT: StreamingFormat, - S: Stream + 'a + Send, + S: Stream> + 'a + Send, + E: Into, { Self { stream: Self::create_stream_frames(&stream_format, stream, &options), - headers: stream_format.http_response_trailers(&options), + headers: stream_format.http_response_headers(&options), } } @@ -62,18 +64,20 @@ impl<'a> StreamBodyAs<'a> { self } - fn create_stream_frames( + fn create_stream_frames( stream_format: &FMT, stream: S, options: &StreamBodyAsOptions, ) -> BoxStream<'a, Result, axum::Error>> where FMT: StreamingFormat, - S: Stream + 'a + Send, + S: Stream> + 'a + Send, + E: Into, { + let boxed_stream = Box::pin(stream.map_err(|e| e.into())); match (options.buffering_ready_items, options.buffering_bytes) { (Some(buffering_ready_items), _) => stream_format - .to_bytes_stream(Box::pin(stream), options) + .to_bytes_stream(boxed_stream, options) .ready_chunks(buffering_ready_items) .map(|chunks| { let mut buf = BytesMut::new(); @@ -84,11 +88,9 @@ impl<'a> StreamBodyAs<'a> { }) .boxed(), (_, Some(buffering_bytes)) => { - let bytes_stream = stream_format - .to_bytes_stream(Box::pin(stream), options) - .chain(futures::stream::once(futures::future::ready(Ok( - bytes::Bytes::new(), - )))); + let bytes_stream = stream_format.to_bytes_stream(boxed_stream, options).chain( + futures::stream::once(futures::future::ready(Ok(bytes::Bytes::new()))), + ); bytes_stream .scan( @@ -116,7 +118,7 @@ impl<'a> StreamBodyAs<'a> { .boxed() } (None, None) => stream_format - .to_bytes_stream(Box::pin(stream), options) + .to_bytes_stream(boxed_stream, options) .map(|res| res.map(Frame::data)) .boxed(), } @@ -198,7 +200,8 @@ mod tests { #[tokio::test] async fn test_stream_body_as() { let stream = futures::stream::iter(vec!["First".to_string(), "Second".to_string()]).boxed(); - let stream_body_as = StreamBodyAs::new(TextStreamFormat::new(), stream); + let stream_body_as = + StreamBodyAs::new(TextStreamFormat::new(), stream.map(Ok::<_, axum::Error>)); let response = stream_body_as.into_response(); assert_eq!( response.headers().get(http::header::CONTENT_TYPE).unwrap(), @@ -221,7 +224,7 @@ mod tests { .boxed(); let stream_body_as = StreamBodyAs::with_options( TextStreamFormat::new(), - stream, + stream.map(Ok::<_, axum::Error>), StreamBodyAsOptions::new().buffering_ready_items(2), ); let response = stream_body_as.into_response(); @@ -246,7 +249,7 @@ mod tests { .boxed(); let stream_body_as = StreamBodyAs::with_options( TextStreamFormat::new(), - stream, + stream.map(Ok::<_, axum::Error>), StreamBodyAsOptions::new().buffering_bytes(3), ); let response = stream_body_as.into_response(); diff --git a/src/stream_format.rs b/src/stream_format.rs index beb8c33..8b5043c 100644 --- a/src/stream_format.rs +++ b/src/stream_format.rs @@ -5,9 +5,9 @@ use http::HeaderMap; pub trait StreamingFormat { fn to_bytes_stream<'a, 'b>( &'a self, - stream: BoxStream<'b, T>, + stream: BoxStream<'b, Result>, options: &'a StreamBodyAsOptions, ) -> BoxStream<'b, Result>; - fn http_response_trailers(&self, options: &StreamBodyAsOptions) -> Option; + fn http_response_headers(&self, options: &StreamBodyAsOptions) -> Option; } diff --git a/src/text_format.rs b/src/text_format.rs index c519af6..b854ef8 100644 --- a/src/text_format.rs +++ b/src/text_format.rs @@ -17,7 +17,7 @@ impl TextStreamFormat { impl StreamingFormat for TextStreamFormat { fn to_bytes_stream<'a, 'b>( &'a self, - stream: BoxStream<'b, String>, + stream: BoxStream<'b, Result>, _: &'a StreamBodyAsOptions, ) -> BoxStream<'b, Result> { fn write_text_record(obj: String) -> Result, axum::Error> { @@ -25,10 +25,13 @@ impl StreamingFormat for TextStreamFormat { Ok(obj_vec) } - Box::pin(stream.map(move |obj| write_text_record(obj).map(|data| data.into()))) + Box::pin(stream.map(move |obj_res| match obj_res { + Err(e) => Err(e), + Ok(obj) => write_text_record(obj).map(|data| data.into()), + })) } - fn http_response_trailers(&self, options: &StreamBodyAsOptions) -> Option { + fn http_response_headers(&self, options: &StreamBodyAsOptions) -> Option { let mut header_map = HeaderMap::new(); header_map.insert( http::header::CONTENT_TYPE, @@ -44,6 +47,17 @@ impl<'a> StreamBodyAs<'a> { pub fn text(stream: S) -> Self where S: Stream + 'a + Send, + { + Self::new( + TextStreamFormat::new(), + stream.map(Ok::), + ) + } + + pub fn text_with_errors(stream: S) -> Self + where + S: Stream> + 'a + Send, + E: Into, { Self::new(TextStreamFormat::new(), stream) } @@ -53,6 +67,18 @@ impl StreamBodyAsOptions { pub fn text<'a, S>(self, stream: S) -> StreamBodyAs<'a> where S: Stream + 'a + Send, + { + StreamBodyAs::with_options( + TextStreamFormat::new(), + stream.map(Ok::), + self, + ) + } + + pub fn text_with_errors<'a, S, E>(self, stream: S) -> StreamBodyAs<'a> + where + S: Stream> + 'a + Send, + E: Into, { StreamBodyAs::with_options(TextStreamFormat::new(), stream, self) } @@ -92,7 +118,12 @@ mod tests { let app = Router::new().route( "/", - get(|| async { StreamBodyAs::new(TextStreamFormat::new(), test_stream) }), + get(|| async { + StreamBodyAs::new( + TextStreamFormat::new(), + test_stream.map(Ok::<_, axum::Error>), + ) + }), ); let client = TestClient::new(app).await;