diff --git a/Cargo.toml b/Cargo.toml
index 6bd3988..1e9c733 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -23,7 +23,6 @@ path = "src/lib.rs"
 [dependencies]
 axum = { version = "0.7" }
 bytes = "1"
-futures-util = { version = "0.3", default-features = false, features = ["alloc"] }
 http = "1"
 http-body = "1"
 mime = "0.3"
@@ -35,12 +34,14 @@ tokio-util = { version = "0.7" }
 futures = "0.3"
 csv = { version = "1.3", optional = true }
 prost = { version= "0.12", optional = true }
+arrow = { version = "51", features = ["ipc"], optional = true }
 
 [features]
 default = []
 json = ["dep:serde", "dep:serde_json"]
 csv = ["dep:csv", "dep:serde"]
 protobuf = ["dep:prost"]
+arrow = ["dep:arrow"]
 text = []
 
 [dev-dependencies]
@@ -53,6 +54,7 @@ tower-layer = "0.3"
 tower-service = "0.3"
 tokio = { version = "1", features = ["full"] }
 prost = { version= "0.12", features = ["prost-derive"] }
+arrow = { version = "51", features = ["ipc"] }
 
 [[example]]
 name = "json-example"
@@ -79,5 +81,10 @@ name = "json-array-complex-structure"
 path = "examples/json-array-complex-structure.rs"
 required-features = ["json"]
 
+[[example]]
+name = "arrow-example"
+path = "examples/arrow-example.rs"
+required-features = ["arrow"]
+
 [build-dependencies]
 cargo-husky = { version = "1.5", default-features = false, features = ["run-for-all", "prepush-hook", "run-cargo-fmt"] }
diff --git a/examples/arrow-example.rs b/examples/arrow-example.rs
new file mode 100644
index 0000000..0551376
--- /dev/null
+++ b/examples/arrow-example.rs
@@ -0,0 +1,49 @@
+use arrow::array::*;
+use arrow::datatypes::*;
+use axum::response::IntoResponse;
+use axum::routing::*;
+use axum::Router;
+use std::sync::Arc;
+
+use futures::prelude::*;
+use tokio::net::TcpListener;
+use tokio_stream::StreamExt;
+
+use axum_streams::*;
+
+fn source_test_stream(schema: Arc<Schema>) -> impl Stream<Item = RecordBatch> {
+    // Simulating a stream with a plain vector and throttling to show how it works
+    stream::iter((0..10).map(move |_| {
+        RecordBatch::try_new(
+            schema.clone(),
+            vec![
+                Arc::new(StringArray::from(vec!["New York", "London", "Gothenburg"])),
+                Arc::new(Float64Array::from(vec![40.7128, 51.5074, 57.7089])),
+                Arc::new(Float64Array::from(vec![-74.0060, -0.1278, 11.9746])),
+            ],
+        )
+        .unwrap()
+    }))
+    .throttle(std::time::Duration::from_millis(50))
+}
+
+async fn test_text_stream() -> impl IntoResponse {
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("city", DataType::Utf8, false),
+        Field::new("lat", DataType::Float64, false),
+        Field::new("lng", DataType::Float64, false),
+    ]));
+    StreamBodyAs::arrow(schema.clone(), source_test_stream(schema.clone()))
+}
+
+#[tokio::main]
+async fn main() {
+    // build our application with a route
+    let app = Router::new()
+        // `GET /` goes to `root`
+        .route("/arrow-stream", get(test_text_stream));
+
+    let listener = TcpListener::bind("127.0.0.1:8080").await.unwrap();
+
+    axum::serve(listener, app).await.unwrap();
+}
diff --git a/src/arrow_format.rs b/src/arrow_format.rs
new file mode 100644
index 0000000..f9a3284
--- /dev/null
+++ b/src/arrow_format.rs
@@ -0,0 +1,162 @@
+use crate::StreamingFormat;
+use arrow::array::RecordBatch;
+use arrow::datatypes::{Schema, SchemaRef};
+use arrow::ipc::writer::{IpcWriteOptions, StreamWriter};
+use bytes::{BufMut, BytesMut};
+use futures::stream::BoxStream;
+use futures::Stream;
+use futures::StreamExt;
+use http::HeaderMap;
+use http_body::Frame;
+use std::sync::Arc;
+
+pub struct ArrowRecordBatchStreamFormat {
+    schema: SchemaRef,
+    options: IpcWriteOptions,
+}
+
+impl ArrowRecordBatchStreamFormat {
+    pub fn new(schema: Arc<Schema>) -> Self {
+        Self::with_options(schema, IpcWriteOptions::default())
+    }
+
+    pub fn with_options(schema: Arc<Schema>, options: IpcWriteOptions) -> Self {
+        Self {
+            schema: schema.clone(),
+            options: options.clone(),
+        }
+    }
+}
+
+impl StreamingFormat<RecordBatch> for ArrowRecordBatchStreamFormat {
+    fn to_bytes_stream<'a, 'b>(
+        &'a self,
+        stream: BoxStream<'b, RecordBatch>,
+    ) -> BoxStream<'b, Result<Frame<axum::body::Bytes>, axum::Error>> {
+        let schema = self.schema.clone();
+        let options = self.options.clone();
+
+        let stream_bytes: BoxStream<'b, Result<Frame<axum::body::Bytes>, axum::Error>> = Box::pin({
+            stream.map(move |batch| {
+                let buf = BytesMut::new().writer();
+                let mut writer = StreamWriter::try_new_with_options(buf, &schema, options.clone())
+                    .map_err(axum::Error::new)?;
+                writer.write(&batch).map_err(axum::Error::new)?;
+                writer.finish().map_err(axum::Error::new)?;
+                writer
+                    .into_inner()
+                    .map_err(axum::Error::new)
+                    .map(|buf| buf.into_inner().freeze())
+                    .map(axum::body::Bytes::from)
+                    .map(Frame::data)
+            })
+        });
+
+        Box::pin(stream_bytes)
+    }
+
+    fn http_response_trailers(&self) -> Option<HeaderMap> {
+        let mut header_map = HeaderMap::new();
+        header_map.insert(
+            http::header::CONTENT_TYPE,
+            http::header::HeaderValue::from_static("application/vnd.apache.arrow.stream"),
+        );
+        Some(header_map)
+    }
+}
+
+impl<'a> crate::StreamBodyAs<'a> {
+    pub fn arrow<S>(schema: SchemaRef, stream: S) -> Self
+    where
+        S: Stream<Item = RecordBatch> + 'a + Send,
+    {
+        Self::new(ArrowRecordBatchStreamFormat::new(schema), stream)
+    }
+
+    pub fn arrow_with_options<S>(schema: SchemaRef, stream: S, options: IpcWriteOptions) -> Self
+    where
+        S: Stream<Item = RecordBatch> + 'a + Send,
+    {
+        Self::new(
+            ArrowRecordBatchStreamFormat::with_options(schema, options),
+            stream,
+        )
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::test_client::*;
+    use crate::StreamBodyAs;
+    use arrow::array::*;
+    use arrow::datatypes::*;
+    use axum::{routing::*, Router};
+    use futures::stream;
+    use std::sync::Arc;
+
+    #[tokio::test]
+    async fn serialize_arrow_stream_format() {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("id", DataType::Int64, false),
+            Field::new("city", DataType::Utf8, false),
+            Field::new("lat", DataType::Float64, false),
+            Field::new("lng", DataType::Float64, false),
+        ]));
+
+        fn create_test_batch(schema_ref: SchemaRef) -> Vec<RecordBatch> {
+            let vec_schema = schema_ref.clone();
+            (0i64..10i64)
+                .map(move |idx| {
+                    RecordBatch::try_new(
+                        vec_schema.clone(),
+                        vec![
+                            Arc::new(Int64Array::from(vec![idx, idx * 2, idx * 3])),
+                            Arc::new(StringArray::from(vec!["New York", "London", "Gothenburg"])),
+                            Arc::new(Float64Array::from(vec![40.7128, 51.5074, 57.7089])),
+                            Arc::new(Float64Array::from(vec![-74.0060, -0.1278, 11.9746])),
+                        ],
+                    )
+                    .unwrap()
+                })
+                .collect()
+        }
+
+        let test_stream = Box::pin(stream::iter(create_test_batch(schema.clone())));
+
+        let app_schema = schema.clone();
+
+        let app = Router::new().route(
+            "/",
+            get(|| async move {
+                StreamBodyAs::new(
+                    ArrowRecordBatchStreamFormat::new(app_schema.clone()),
+                    test_stream,
+                )
+            }),
+        );
+
+        let client = TestClient::new(app).await;
+
+        let expected_proto_buf: Vec<u8> = create_test_batch(schema.clone())
+            .iter()
+            .flat_map(|batch| {
+                let mut writer = StreamWriter::try_new(Vec::new(), &schema).expect("writer failed");
+                writer.write(&batch).expect("write failed");
+                writer.finish().expect("writer failed");
+                writer.into_inner().expect("writer failed")
+            })
+            .collect();
+
+        let res = client.get("/").send().await.unwrap();
+        assert_eq!(
+            res.headers()
+                .get("content-type")
+                .and_then(|h| h.to_str().ok()),
Some("application/vnd.apache.arrow.stream") + ); + let body = res.bytes().await.unwrap().to_vec(); + + assert_eq!(body, expected_proto_buf); + } +} diff --git a/src/csv_format.rs b/src/csv_format.rs index b2a2b60..b2122fe 100644 --- a/src/csv_format.rs +++ b/src/csv_format.rs @@ -1,7 +1,7 @@ use crate::stream_format::StreamingFormat; +use futures::stream::BoxStream; use futures::Stream; -use futures_util::stream::BoxStream; -use futures_util::StreamExt; +use futures::StreamExt; use http::HeaderMap; use http_body::Frame; use serde::Serialize; @@ -159,7 +159,7 @@ mod tests { use crate::test_client::*; use crate::StreamBodyAs; use axum::{routing::*, Router}; - use futures_util::stream; + use futures::stream; use std::ops::Add; #[tokio::test] diff --git a/src/json_formats.rs b/src/json_formats.rs index 43d5ef8..c6e93db 100644 --- a/src/json_formats.rs +++ b/src/json_formats.rs @@ -1,9 +1,9 @@ use crate::stream_format::StreamingFormat; use crate::StreamFormatEnvelope; use bytes::{BufMut, BytesMut}; +use futures::stream::BoxStream; use futures::Stream; -use futures_util::stream::BoxStream; -use futures_util::StreamExt; +use futures::StreamExt; use http::HeaderMap; use http_body::Frame; use serde::Serialize; @@ -66,7 +66,7 @@ where }); let prepend_stream: BoxStream, axum::Error>> = - Box::pin(futures_util::stream::once(futures_util::future::ready({ + Box::pin(futures::stream::once(futures::future::ready({ if let Some(envelope) = &self.envelope { match serde_json::to_vec(&envelope.object) { Ok(envelope_bytes) if envelope_bytes.len() > 1 => { @@ -108,7 +108,7 @@ where }))); let append_stream: BoxStream, axum::Error>> = - Box::pin(futures_util::stream::once(futures_util::future::ready({ + Box::pin(futures::stream::once(futures::future::ready({ if self.envelope.is_some() { Ok::<_, axum::Error>(Frame::data(axum::body::Bytes::from( JSON_ARRAY_ENVELOP_END_BYTES, @@ -216,7 +216,7 @@ mod tests { use crate::test_client::*; use crate::StreamBodyAs; use axum::{routing::*, Router}; - use futures_util::stream; + use futures::stream; #[tokio::test] async fn serialize_json_array_stream_format() { diff --git a/src/lib.rs b/src/lib.rs index a985cb4..934e126 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -100,5 +100,10 @@ mod protobuf_format; #[cfg(feature = "protobuf")] pub use protobuf_format::ProtobufStreamFormat; +#[cfg(feature = "arrow")] +mod arrow_format; +#[cfg(feature = "arrow")] +pub use arrow_format::ArrowRecordBatchStreamFormat; + #[cfg(test)] mod test_client; diff --git a/src/protobuf_format.rs b/src/protobuf_format.rs index 631ae54..d296c9f 100644 --- a/src/protobuf_format.rs +++ b/src/protobuf_format.rs @@ -1,7 +1,7 @@ use crate::stream_format::StreamingFormat; +use futures::stream::BoxStream; use futures::Stream; -use futures_util::stream::BoxStream; -use futures_util::StreamExt; +use futures::StreamExt; use http::HeaderMap; use http_body::Frame; @@ -72,7 +72,7 @@ mod tests { use crate::test_client::*; use crate::StreamBodyAs; use axum::{routing::*, Router}; - use futures_util::stream; + use futures::stream; use prost::Message; #[tokio::test] diff --git a/src/stream_body_as.rs b/src/stream_body_as.rs index 0e51405..568df8d 100644 --- a/src/stream_body_as.rs +++ b/src/stream_body_as.rs @@ -1,8 +1,8 @@ use crate::stream_format::StreamingFormat; use axum::body::{Body, HttpBody}; use axum::response::{IntoResponse, Response}; +use futures::stream::BoxStream; use futures::Stream; -use futures_util::stream::BoxStream; use http::HeaderMap; use http_body::Frame; use std::fmt::Formatter; diff --git 
diff --git a/src/stream_format.rs b/src/stream_format.rs
index f34d857..15789ee 100644
--- a/src/stream_format.rs
+++ b/src/stream_format.rs
@@ -1,4 +1,4 @@
-use futures_util::stream::BoxStream;
+use futures::stream::BoxStream;
 use http::HeaderMap;
 
 pub trait StreamingFormat<T> {
diff --git a/src/text_format.rs b/src/text_format.rs
index 06b7d3b..5503b36 100644
--- a/src/text_format.rs
+++ b/src/text_format.rs
@@ -1,7 +1,7 @@
 use crate::stream_format::StreamingFormat;
+use futures::stream::BoxStream;
 use futures::Stream;
-use futures_util::stream::BoxStream;
-use futures_util::StreamExt;
+use futures::StreamExt;
 use http::HeaderMap;
 use http_body::Frame;
 
@@ -55,7 +55,7 @@ mod tests {
     use crate::test_client::*;
     use crate::StreamBodyAs;
     use axum::{routing::*, Router};
-    use futures_util::stream;
+    use futures::stream;
 
     #[tokio::test]
     async fn serialize_text_stream_format() {
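
Reviewer note (illustration only, not part of the diff): ArrowRecordBatchStreamFormat creates a fresh StreamWriter for every RecordBatch, so each streamed frame is a self-contained Arrow IPC stream (schema header, the batch, end-of-stream marker) and can be decoded on its own. A minimal sketch of that round trip, assuming the same arrow 51 crate pinned above; names like frame_bytes are made up for the example:

use arrow::array::{ArrayRef, Int64Array, RecordBatch};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::ipc::reader::StreamReader;
use arrow::ipc::writer::StreamWriter;
use std::io::Cursor;
use std::sync::Arc;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int64Array::from(vec![1_i64, 2, 3])) as ArrayRef],
    )?;

    // Encode a single batch the same way the format does for each stream item
    let mut writer = StreamWriter::try_new(Vec::new(), &schema)?;
    writer.write(&batch)?;
    writer.finish()?;
    let frame_bytes = writer.into_inner()?;

    // A consumer can decode each frame independently with arrow's StreamReader
    let mut reader = StreamReader::try_new(Cursor::new(frame_bytes), None)?;
    let decoded = reader.next().expect("expected one batch")?;
    assert_eq!(decoded, batch);
    Ok(())
}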