Skip to content

Commit

Permalink
Apache Arrow streaming format support
Browse files Browse the repository at this point in the history
  • Loading branch information
abdolence committed Mar 30, 2024
1 parent df1dcf5 commit dd2b7ac
Show file tree
Hide file tree
Showing 10 changed files with 240 additions and 17 deletions.
9 changes: 8 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ path = "src/lib.rs"
[dependencies]
axum = { version = "0.7" }
bytes = "1"
futures-util = { version = "0.3", default-features = false, features = ["alloc"] }
http = "1"
http-body = "1"
mime = "0.3"
Expand All @@ -35,12 +34,14 @@ tokio-util = { version = "0.7" }
futures = "0.3"
csv = { version = "1.3", optional = true }
prost = { version= "0.12", optional = true }
arrow = { version = "51", features = ["ipc"], optional = true }

[features]
default = []
json = ["dep:serde", "dep:serde_json"]
csv = ["dep:csv", "dep:serde"]
protobuf = ["dep:prost"]
arrow = ["dep:arrow"]
text = []

[dev-dependencies]
Expand All @@ -53,6 +54,7 @@ tower-layer = "0.3"
tower-service = "0.3"
tokio = { version = "1", features = ["full"] }
prost = { version= "0.12", features = ["prost-derive"] }
arrow = { version = "51", features = ["ipc"] }

[[example]]
name = "json-example"
Expand All @@ -79,5 +81,10 @@ name = "json-array-complex-structure"
path = "examples/json-array-complex-structure.rs"
required-features = ["json"]

[[example]]
name = "arrow-example"
path = "examples/arrow-example.rs"
required-features = ["arrow"]

[build-dependencies]
cargo-husky = { version = "1.5", default-features = false, features = ["run-for-all", "prepush-hook", "run-cargo-fmt"] }
49 changes: 49 additions & 0 deletions examples/arrow-example.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
use arrow::array::*;
use arrow::datatypes::*;
use axum::response::IntoResponse;
use axum::routing::*;
use axum::Router;
use std::sync::Arc;

use futures::prelude::*;
use tokio::net::TcpListener;
use tokio_stream::StreamExt;

use axum_streams::*;

fn source_test_stream(schema: Arc<Schema>) -> impl Stream<Item = RecordBatch> {
// Simulating a stream with a plain vector and throttling to show how it works
stream::iter((0..10).map(move |_| {
RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(StringArray::from(vec!["New York", "London", "Gothenburg"])),
Arc::new(Float64Array::from(vec![40.7128, 51.5074, 57.7089])),
Arc::new(Float64Array::from(vec![-74.0060, -0.1278, 11.9746])),
],
)
.unwrap()
}))
.throttle(std::time::Duration::from_millis(50))
}

/// Handler for `GET /arrow-stream`.
///
/// Declares the `city` / `lat` / `lng` schema and responds with the demo
/// record batches encoded through `StreamBodyAs::arrow`.
async fn test_text_stream() -> impl IntoResponse {
    let schema = Arc::new(Schema::new(vec![
        Field::new("city", DataType::Utf8, false),
        Field::new("lat", DataType::Float64, false),
        Field::new("lng", DataType::Float64, false),
    ]));
    // One handle goes to the response format; the batch generator takes over
    // the original, so no second clone is needed.
    StreamBodyAs::arrow(Arc::clone(&schema), source_test_stream(schema))
}

/// Starts the example server on `127.0.0.1:8080`.
///
/// Panics if the address cannot be bound or if serving fails.
#[tokio::main]
async fn main() {
    let listener = TcpListener::bind("127.0.0.1:8080").await.unwrap();

    // The application exposes a single route:
    // `GET /arrow-stream` is answered by `test_text_stream`.
    let app = Router::new().route("/arrow-stream", get(test_text_stream));

    axum::serve(listener, app).await.unwrap();
}
162 changes: 162 additions & 0 deletions src/arrow_format.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
use crate::StreamingFormat;
use arrow::array::RecordBatch;
use arrow::datatypes::{Schema, SchemaRef};
use arrow::ipc::writer::{IpcWriteOptions, StreamWriter};
use bytes::{BufMut, BytesMut};
use futures::stream::BoxStream;
use futures::Stream;
use futures::StreamExt;
use http::HeaderMap;
use http_body::Frame;
use std::sync::Arc;

/// Streaming format that serializes [`RecordBatch`] items as Apache Arrow IPC
/// stream bytes.
#[derive(Debug, Clone)]
pub struct ArrowRecordBatchStreamFormat {
    /// Schema handed to the IPC encoder; streamed batches are expected to match it.
    schema: SchemaRef,
    /// IPC write options applied when encoding.
    options: IpcWriteOptions,
}

impl ArrowRecordBatchStreamFormat {
    /// Creates a format for `schema` using `IpcWriteOptions::default()`.
    pub fn new(schema: Arc<Schema>) -> Self {
        Self::with_options(schema, IpcWriteOptions::default())
    }

    /// Creates a format for `schema` using the caller-supplied IPC write `options`.
    pub fn with_options(schema: Arc<Schema>, options: IpcWriteOptions) -> Self {
        // Both arguments are already owned, so they are moved in rather than cloned.
        Self { schema, options }
    }
}

impl StreamingFormat<RecordBatch> for ArrowRecordBatchStreamFormat {
fn to_bytes_stream<'a, 'b>(
&'a self,
stream: BoxStream<'b, RecordBatch>,
) -> BoxStream<'b, Result<Frame<axum::body::Bytes>, axum::Error>> {
let schema = self.schema.clone();
let options = self.options.clone();

let stream_bytes: BoxStream<Result<Frame<axum::body::Bytes>, axum::Error>> = Box::pin({
stream.map(move |batch| {
let buf = BytesMut::new().writer();
let mut writer = StreamWriter::try_new_with_options(buf, &schema, options.clone())
.map_err(axum::Error::new)?;
writer.write(&batch).map_err(axum::Error::new)?;
writer.finish().map_err(axum::Error::new)?;
writer
.into_inner()
.map_err(axum::Error::new)
.map(|buf| buf.into_inner().freeze())
.map(axum::body::Bytes::from)
.map(Frame::data)
})
});

Box::pin(stream_bytes)
}

fn http_response_trailers(&self) -> Option<HeaderMap> {
let mut header_map = HeaderMap::new();
header_map.insert(
http::header::CONTENT_TYPE,
http::header::HeaderValue::from_static("application/vnd.apache.arrow.stream"),
);
Some(header_map)
}
}

impl<'a> crate::StreamBodyAs<'a> {
pub fn arrow<S>(schema: SchemaRef, stream: S) -> Self
where
S: Stream<Item = RecordBatch> + 'a + Send,
{
Self::new(ArrowRecordBatchStreamFormat::new(schema), stream)
}

pub fn arrow_with_options<S>(schema: SchemaRef, stream: S, options: IpcWriteOptions) -> Self
where
S: Stream<Item = RecordBatch> + 'a + Send,
{
Self::new(
ArrowRecordBatchStreamFormat::with_options(schema, options),
stream,
)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::test_client::*;
use crate::StreamBodyAs;
use arrow::array::*;
use arrow::datatypes::*;
use axum::{routing::*, Router};
use futures::stream;
use std::sync::Arc;

#[tokio::test]
async fn serialize_arrow_stream_format() {
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int64, false),
Field::new("city", DataType::Utf8, false),
Field::new("lat", DataType::Float64, false),
Field::new("lng", DataType::Float64, false),
]));

fn create_test_batch(schema_ref: SchemaRef) -> Vec<RecordBatch> {
let vec_schema = schema_ref.clone();
(0i64..10i64)
.map(move |idx| {
RecordBatch::try_new(
vec_schema.clone(),
vec![
Arc::new(Int64Array::from(vec![idx, idx * 2, idx * 3])),
Arc::new(StringArray::from(vec!["New York", "London", "Gothenburg"])),
Arc::new(Float64Array::from(vec![40.7128, 51.5074, 57.7089])),
Arc::new(Float64Array::from(vec![-74.0060, -0.1278, 11.9746])),
],
)
.unwrap()
})
.collect()
}

let test_stream = Box::pin(stream::iter(create_test_batch(schema.clone())));

let app_schema = schema.clone();

let app = Router::new().route(
"/",
get(|| async move {
StreamBodyAs::new(
ArrowRecordBatchStreamFormat::new(app_schema.clone()),
test_stream,
)
}),
);

let client = TestClient::new(app).await;

let expected_proto_buf: Vec<u8> = create_test_batch(schema.clone())
.iter()
.flat_map(|batch| {
let mut writer = StreamWriter::try_new(Vec::new(), &schema).expect("writer failed");
writer.write(&batch).expect("write failed");
writer.finish().expect("writer failed");
writer.into_inner().expect("writer failed")
})
.collect();

let res = client.get("/").send().await.unwrap();
assert_eq!(
res.headers()
.get("content-type")
.and_then(|h| h.to_str().ok()),
Some("application/vnd.apache.arrow.stream")
);
let body = res.bytes().await.unwrap().to_vec();

assert_eq!(body, expected_proto_buf);
}
}
6 changes: 3 additions & 3 deletions src/csv_format.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::stream_format::StreamingFormat;
use futures::stream::BoxStream;
use futures::Stream;
use futures_util::stream::BoxStream;
use futures_util::StreamExt;
use futures::StreamExt;
use http::HeaderMap;
use http_body::Frame;
use serde::Serialize;
Expand Down Expand Up @@ -159,7 +159,7 @@ mod tests {
use crate::test_client::*;
use crate::StreamBodyAs;
use axum::{routing::*, Router};
use futures_util::stream;
use futures::stream;
use std::ops::Add;

#[tokio::test]
Expand Down
10 changes: 5 additions & 5 deletions src/json_formats.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use crate::stream_format::StreamingFormat;
use crate::StreamFormatEnvelope;
use bytes::{BufMut, BytesMut};
use futures::stream::BoxStream;
use futures::Stream;
use futures_util::stream::BoxStream;
use futures_util::StreamExt;
use futures::StreamExt;
use http::HeaderMap;
use http_body::Frame;
use serde::Serialize;
Expand Down Expand Up @@ -66,7 +66,7 @@ where
});

let prepend_stream: BoxStream<Result<Frame<axum::body::Bytes>, axum::Error>> =
Box::pin(futures_util::stream::once(futures_util::future::ready({
Box::pin(futures::stream::once(futures::future::ready({
if let Some(envelope) = &self.envelope {
match serde_json::to_vec(&envelope.object) {
Ok(envelope_bytes) if envelope_bytes.len() > 1 => {
Expand Down Expand Up @@ -108,7 +108,7 @@ where
})));

let append_stream: BoxStream<Result<Frame<axum::body::Bytes>, axum::Error>> =
Box::pin(futures_util::stream::once(futures_util::future::ready({
Box::pin(futures::stream::once(futures::future::ready({
if self.envelope.is_some() {
Ok::<_, axum::Error>(Frame::data(axum::body::Bytes::from(
JSON_ARRAY_ENVELOP_END_BYTES,
Expand Down Expand Up @@ -216,7 +216,7 @@ mod tests {
use crate::test_client::*;
use crate::StreamBodyAs;
use axum::{routing::*, Router};
use futures_util::stream;
use futures::stream;

#[tokio::test]
async fn serialize_json_array_stream_format() {
Expand Down
5 changes: 5 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,5 +100,10 @@ mod protobuf_format;
#[cfg(feature = "protobuf")]
pub use protobuf_format::ProtobufStreamFormat;

#[cfg(feature = "arrow")]
mod arrow_format;
#[cfg(feature = "arrow")]
pub use arrow_format::ArrowRecordBatchStreamFormat;

#[cfg(test)]
mod test_client;
6 changes: 3 additions & 3 deletions src/protobuf_format.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::stream_format::StreamingFormat;
use futures::stream::BoxStream;
use futures::Stream;
use futures_util::stream::BoxStream;
use futures_util::StreamExt;
use futures::StreamExt;
use http::HeaderMap;
use http_body::Frame;

Expand Down Expand Up @@ -72,7 +72,7 @@ mod tests {
use crate::test_client::*;
use crate::StreamBodyAs;
use axum::{routing::*, Router};
use futures_util::stream;
use futures::stream;
use prost::Message;

#[tokio::test]
Expand Down
2 changes: 1 addition & 1 deletion src/stream_body_as.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use crate::stream_format::StreamingFormat;
use axum::body::{Body, HttpBody};
use axum::response::{IntoResponse, Response};
use futures::stream::BoxStream;
use futures::Stream;
use futures_util::stream::BoxStream;
use http::HeaderMap;
use http_body::Frame;
use std::fmt::Formatter;
Expand Down
2 changes: 1 addition & 1 deletion src/stream_format.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use futures_util::stream::BoxStream;
use futures::stream::BoxStream;
use http::HeaderMap;

pub trait StreamingFormat<T> {
Expand Down
6 changes: 3 additions & 3 deletions src/text_format.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::stream_format::StreamingFormat;
use futures::stream::BoxStream;
use futures::Stream;
use futures_util::stream::BoxStream;
use futures_util::StreamExt;
use futures::StreamExt;
use http::HeaderMap;
use http_body::Frame;

Expand Down Expand Up @@ -55,7 +55,7 @@ mod tests {
use crate::test_client::*;
use crate::StreamBodyAs;
use axum::{routing::*, Router};
use futures_util::stream;
use futures::stream;

#[tokio::test]
async fn serialize_text_stream_format() {
Expand Down

0 comments on commit dd2b7ac

Please sign in to comment.