Skip to content

Commit

Permalink
Enhanced error handling (#57)
Browse files Browse the repository at this point in the history
* Renamed function to make it more clear

* API update to accept Result<>

* _with_errors functions, docs and example
  • Loading branch information
abdolence authored Jun 20, 2024
1 parent 6871c74 commit b13feee
Show file tree
Hide file tree
Showing 10 changed files with 459 additions and 94 deletions.
9 changes: 8 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ license = "Apache-2.0"
name = "axum-streams"
readme = "README.md"
include = ["Cargo.toml", "src/**/*.rs", "README.md", "LICENSE"]
version = "0.16.0"
version = "0.17.0"

[badges]
maintenance = { status = "actively-developed" }
Expand Down Expand Up @@ -55,6 +55,8 @@ tower-service = "0.3"
tokio = { version = "1", features = ["full"] }
prost = { version= "0.12", features = ["prost-derive"] }
arrow = { version = "52", features = ["ipc"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3"}

[[example]]
name = "json-example"
Expand Down Expand Up @@ -86,6 +88,11 @@ name = "json-with-buffering"
path = "examples/json-with-buffering.rs"
required-features = ["json"]

[[example]]
name = "json-with-errors-example"
path = "examples/json-with-errors-example.rs"
required-features = ["json"]

[[example]]
name = "arrow-example"
path = "examples/arrow-example.rs"
Expand Down
31 changes: 31 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,37 @@ You can change this is using `StreamAsOptions`:
.json_array(source_test_stream())
```

## Error handling
The library provides a way to propagate errors in the stream:

```rust
struct MyError {
message: String,
}

impl Into<axum::Error> for MyError {
fn into(self) -> axum::Error {
axum::Error::new(self.message)
}
}

fn my_source_stream() -> impl Stream<Item=Result<MyTestStructure, MyError>> {
// Simulating a stream with a plain vector and throttling to show how it works
stream::iter(vec![
Ok(MyTestStructure {
some_test_field: "test1".to_string()
}); 1000
])
}

async fn test_json_array_stream() -> impl IntoResponse {
// Use _with_errors functions or directly `StreamBodyAs::with_options`
// to produce a stream with errors
StreamBodyAs::json_array_with_errors(source_test_stream())
}

```

## JSON array inside another object
Sometimes you need to include your array inside some object, e.g.:
```json
Expand Down
72 changes: 72 additions & 0 deletions examples/json-with-errors-example.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
use axum::response::IntoResponse;
use axum::routing::*;
use axum::Router;
use futures::{stream, Stream, StreamExt};

use serde::{Deserialize, Serialize};
use tokio::net::TcpListener;

use axum_streams::*;

#[derive(Debug, Clone, Deserialize, Serialize)]
struct MyTestStructure {
some_test_field: String,
}

struct MyError {
message: String,
}

impl Into<axum::Error> for MyError {
fn into(self) -> axum::Error {
axum::Error::new(self.message)
}
}

fn source_test_stream() -> impl Stream<Item = Result<MyTestStructure, MyError>> {
// Simulating a stream with a plain vector and throttling to show how it works
tokio_stream::StreamExt::throttle(
stream::iter(vec![
MyTestStructure {
some_test_field: "test1".to_string()
};
10000
])
.enumerate()
.map(|(idx, item)| {
if idx != 0 && idx % 10 == 0 {
Err(MyError {
message: format!("Error at index {}", idx),
})
} else {
Ok(item)
}
}),
std::time::Duration::from_millis(500),
)
}

async fn test_json_array_stream() -> impl IntoResponse {
StreamBodyAs::json_array_with_errors(source_test_stream())
}

async fn test_json_nl_stream() -> impl IntoResponse {
StreamBodyAs::json_nl_with_errors(source_test_stream())
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
tracing_subscriber::fmt().with_target(false).init();

// build our application with a route
let app = Router::new()
// `GET /` goes to `root`
.route("/json-array-stream", get(test_json_array_stream))
.route("/json-nl-stream", get(test_json_nl_stream));

let listener = TcpListener::bind("127.0.0.1:8080").await?;

axum::serve(listener, app).await?;

Ok(())
}
68 changes: 62 additions & 6 deletions src/arrow_format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ impl ArrowRecordBatchIpcStreamFormat {
impl StreamingFormat<RecordBatch> for ArrowRecordBatchIpcStreamFormat {
fn to_bytes_stream<'a, 'b>(
&'a self,
stream: BoxStream<'b, RecordBatch>,
stream: BoxStream<'b, Result<RecordBatch, axum::Error>>,
_: &'a StreamBodyAsOptions,
) -> BoxStream<'b, Result<axum::body::Bytes, axum::Error>> {
fn write_batch(
Expand Down Expand Up @@ -81,8 +81,9 @@ impl StreamingFormat<RecordBatch> for ArrowRecordBatchIpcStreamFormat {
let batch_stream = Box::pin({
stream.scan(
(ipc_data_gen, dictionary_tracker, 0),
move |(ipc_data_gen, dictionary_tracker, idx), batch| {
futures::future::ready({
move |(ipc_data_gen, dictionary_tracker, idx), batch_res| match batch_res {
Err(e) => futures::future::ready(Some(Err(e))),
Ok(batch) => futures::future::ready({
let prepend_schema = if *idx == 0 {
Some(batch_schema.clone())
} else {
Expand All @@ -98,7 +99,7 @@ impl StreamingFormat<RecordBatch> for ArrowRecordBatchIpcStreamFormat {
)
.map_err(axum::Error::new);
Some(bytes)
})
}),
},
)
});
Expand All @@ -111,7 +112,7 @@ impl StreamingFormat<RecordBatch> for ArrowRecordBatchIpcStreamFormat {
Box::pin(batch_stream.chain(append_stream))
}

fn http_response_trailers(&self, options: &StreamBodyAsOptions) -> Option<HeaderMap> {
fn http_response_headers(&self, options: &StreamBodyAsOptions) -> Option<HeaderMap> {
let mut header_map = HeaderMap::new();
header_map.insert(
http::header::CONTENT_TYPE,
Expand All @@ -127,13 +128,39 @@ impl<'a> crate::StreamBodyAs<'a> {
pub fn arrow_ipc<S>(schema: SchemaRef, stream: S) -> Self
where
S: Stream<Item = RecordBatch> + 'a + Send,
{
Self::new(
ArrowRecordBatchIpcStreamFormat::new(schema),
stream.map(Ok::<RecordBatch, axum::Error>),
)
}

pub fn arrow_ipc_with_errors<S, E>(schema: SchemaRef, stream: S) -> Self
where
S: Stream<Item = Result<RecordBatch, E>> + 'a + Send,
E: Into<axum::Error>,
{
Self::new(ArrowRecordBatchIpcStreamFormat::new(schema), stream)
}

pub fn arrow_ipc_with_options<S>(schema: SchemaRef, stream: S, options: IpcWriteOptions) -> Self
where
S: Stream<Item = RecordBatch> + 'a + Send,
{
Self::new(
ArrowRecordBatchIpcStreamFormat::with_options(schema, options),
stream.map(Ok::<RecordBatch, axum::Error>),
)
}

pub fn arrow_ipc_with_options_errors<S, E>(
schema: SchemaRef,
stream: S,
options: IpcWriteOptions,
) -> Self
where
S: Stream<Item = Result<RecordBatch, E>> + 'a + Send,
E: Into<axum::Error>,
{
Self::new(
ArrowRecordBatchIpcStreamFormat::with_options(schema, options),
Expand All @@ -146,6 +173,18 @@ impl StreamBodyAsOptions {
pub fn arrow_ipc<'a, S>(self, schema: SchemaRef, stream: S) -> StreamBodyAs<'a>
where
S: Stream<Item = RecordBatch> + 'a + Send,
{
StreamBodyAs::with_options(
ArrowRecordBatchIpcStreamFormat::new(schema),
stream.map(Ok::<RecordBatch, axum::Error>),
self,
)
}

pub fn arrow_ipc_with_errors<'a, S, E>(self, schema: SchemaRef, stream: S) -> StreamBodyAs<'a>
where
S: Stream<Item = Result<RecordBatch, E>> + 'a + Send,
E: Into<axum::Error>,
{
StreamBodyAs::with_options(ArrowRecordBatchIpcStreamFormat::new(schema), stream, self)
}
Expand All @@ -158,6 +197,23 @@ impl StreamBodyAsOptions {
) -> StreamBodyAs<'a>
where
S: Stream<Item = RecordBatch> + 'a + Send,
{
StreamBodyAs::with_options(
ArrowRecordBatchIpcStreamFormat::with_options(schema, options),
stream.map(Ok::<RecordBatch, axum::Error>),
self,
)
}

pub fn arrow_ipc_with_options_errors<'a, S, E>(
self,
schema: SchemaRef,
stream: S,
options: IpcWriteOptions,
) -> StreamBodyAs<'a>
where
S: Stream<Item = Result<RecordBatch, E>> + 'a + Send,
E: Into<axum::Error>,
{
StreamBodyAs::with_options(
ArrowRecordBatchIpcStreamFormat::with_options(schema, options),
Expand Down Expand Up @@ -214,7 +270,7 @@ mod tests {
get(|| async move {
StreamBodyAs::new(
ArrowRecordBatchIpcStreamFormat::new(app_schema.clone()),
test_stream,
test_stream.map(Ok::<_, axum::Error>),
)
}),
);
Expand Down
74 changes: 52 additions & 22 deletions src/csv_format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ where
{
fn to_bytes_stream<'a, 'b>(
&'a self,
stream: BoxStream<'b, T>,
stream: BoxStream<'b, Result<T, axum::Error>>,
_: &'a StreamBodyAsOptions,
) -> BoxStream<'b, Result<axum::body::Bytes, axum::Error>> {
let stream_with_header = self.has_headers;
Expand All @@ -110,29 +110,34 @@ where
let terminator = self.terminator;

Box::pin({
stream.enumerate().map(move |(index, obj)| {
let mut writer = csv::WriterBuilder::new()
.has_headers(index == 0 && stream_with_header)
.delimiter(stream_delimiter)
.flexible(stream_flexible)
.quote_style(stream_quote_style)
.quote(stream_quote)
.double_quote(stream_double_quote)
.escape(stream_escape)
.terminator(terminator)
.from_writer(vec![]);

writer.serialize(obj).map_err(axum::Error::new)?;
writer.flush().map_err(axum::Error::new)?;
writer
.into_inner()
.map_err(axum::Error::new)
.map(axum::body::Bytes::from)
})
stream
.enumerate()
.map(move |(index, obj_res)| match obj_res {
Err(e) => Err(e),
Ok(obj) => {
let mut writer = csv::WriterBuilder::new()
.has_headers(index == 0 && stream_with_header)
.delimiter(stream_delimiter)
.flexible(stream_flexible)
.quote_style(stream_quote_style)
.quote(stream_quote)
.double_quote(stream_double_quote)
.escape(stream_escape)
.terminator(terminator)
.from_writer(vec![]);

writer.serialize(obj).map_err(axum::Error::new)?;
writer.flush().map_err(axum::Error::new)?;
writer
.into_inner()
.map_err(axum::Error::new)
.map(axum::body::Bytes::from)
}
})
})
}

fn http_response_trailers(&self, options: &StreamBodyAsOptions) -> Option<HeaderMap> {
fn http_response_headers(&self, options: &StreamBodyAsOptions) -> Option<HeaderMap> {
let mut header_map = HeaderMap::new();
header_map.insert(
http::header::CONTENT_TYPE,
Expand All @@ -150,6 +155,18 @@ impl<'a> StreamBodyAs<'a> {
where
T: Serialize + Send + Sync + 'static,
S: Stream<Item = T> + 'a + Send,
{
Self::new(
CsvStreamFormat::new(false, b','),
stream.map(Ok::<T, axum::Error>),
)
}

pub fn csv_with_errors<S, T, E>(stream: S) -> Self
where
T: Serialize + Send + Sync + 'static,
S: Stream<Item = Result<T, E>> + 'a + Send,
E: Into<axum::Error> + 'static,
{
Self::new(CsvStreamFormat::new(false, b','), stream)
}
Expand All @@ -160,6 +177,19 @@ impl StreamBodyAsOptions {
where
T: Serialize + Send + Sync + 'static,
S: Stream<Item = T> + 'a + Send,
{
StreamBodyAs::with_options(
CsvStreamFormat::new(false, b','),
stream.map(Ok::<T, axum::Error>),
self,
)
}

pub fn csv_with_errors<'a, S, T, E>(self, stream: S) -> StreamBodyAs<'a>
where
T: Serialize + Send + Sync + 'static,
S: Stream<Item = Result<T, E>> + 'a + Send,
E: Into<axum::Error> + 'static,
{
StreamBodyAs::with_options(CsvStreamFormat::new(false, b','), stream, self)
}
Expand Down Expand Up @@ -197,7 +227,7 @@ mod tests {
get(|| async {
StreamBodyAs::new(
CsvStreamFormat::new(false, b'.').with_delimiter(b','),
test_stream,
test_stream.map(Ok::<_, axum::Error>),
)
}),
);
Expand Down
Loading

0 comments on commit b13feee

Please sign in to comment.