Skip to content

Commit

Permalink
Fix arrow stream serialization (#41)
Browse files Browse the repository at this point in the history
  • Loading branch information
abdolence authored Apr 13, 2024
1 parent 0fa4e57 commit f5568c9
Showing 1 changed file with 86 additions and 27 deletions.
113 changes: 86 additions & 27 deletions src/arrow_format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ use crate::stream_body_as::StreamBodyAsOptions;
use crate::{StreamBodyAs, StreamingFormat};
use arrow::array::RecordBatch;
use arrow::datatypes::{Schema, SchemaRef};
use arrow::ipc::writer::{IpcWriteOptions, StreamWriter};
use arrow::error::ArrowError;
use arrow::ipc::writer::{write_message, DictionaryTracker, IpcDataGenerator, IpcWriteOptions};
use bytes::{BufMut, BytesMut};
use futures::stream::BoxStream;
use futures::Stream;
use futures::StreamExt;
use http::HeaderMap;
use std::io::Write;
use std::sync::Arc;

pub struct ArrowRecordBatchIpcStreamFormat {
Expand All @@ -33,23 +35,81 @@ impl StreamingFormat<RecordBatch> for ArrowRecordBatchIpcStreamFormat {
&'a self,
stream: BoxStream<'b, RecordBatch>,
) -> BoxStream<'b, Result<axum::body::Bytes, axum::Error>> {
let schema = self.schema.clone();
let options = self.options.clone();

Box::pin({
stream.map(move |batch| {
let buf = BytesMut::new().writer();
let mut writer = StreamWriter::try_new_with_options(buf, &schema, options.clone())
.map_err(axum::Error::new)?;
writer.write(&batch).map_err(axum::Error::new)?;
writer.finish().map_err(axum::Error::new)?;
writer
.into_inner()
.map_err(axum::Error::new)
.map(|buf| buf.into_inner().freeze())
.map(axum::body::Bytes::from)
})
})
fn write_batch(
ipc_data_gen: &mut IpcDataGenerator,
dictionary_tracker: &mut DictionaryTracker,
write_options: &IpcWriteOptions,
batch: &RecordBatch,
prepend_schema: Option<Arc<Schema>>,
) -> Result<axum::body::Bytes, ArrowError> {
let mut writer = BytesMut::new().writer();

if let Some(prepend_schema) = prepend_schema {
let encoded_message = ipc_data_gen.schema_to_bytes(&prepend_schema, write_options);
write_message(&mut writer, encoded_message, write_options)?;
}

let (encoded_dictionaries, encoded_message) = ipc_data_gen
.encoded_batch(batch, dictionary_tracker, write_options)
.expect("StreamWriter is configured to not error on dictionary replacement");

for encoded_dictionary in encoded_dictionaries {
write_message(&mut writer, encoded_dictionary, write_options)?;
}

write_message(&mut writer, encoded_message, write_options)?;
writer.flush()?;
Ok(writer.into_inner().freeze())
}

fn write_continuation() -> Result<axum::body::Bytes, ArrowError> {
let mut writer = BytesMut::new().writer();
const CONTINUATION_MARKER: [u8; 4] = [0xff; 4];
let total_len = 0_i32.to_le_bytes(); // Always zero in the stream format

writer.write_all(&CONTINUATION_MARKER)?;
writer.write_all(&total_len[..])?;
writer.flush()?;
Ok(writer.into_inner().freeze())
}

let batch_schema = self.schema.clone();
let batch_options = self.options.clone();

let ipc_data_gen = IpcDataGenerator::default();
let dictionary_tracker: DictionaryTracker = DictionaryTracker::new(false);

let batch_stream = Box::pin({
stream.scan(
(ipc_data_gen, dictionary_tracker, 0),
move |(ipc_data_gen, dictionary_tracker, idx), batch| {
futures::future::ready({
let prepend_schema = if *idx == 0 {
Some(batch_schema.clone())
} else {
None
};
*idx += 1;
let bytes = write_batch(
ipc_data_gen,
dictionary_tracker,
&batch_options,
&batch,
prepend_schema,
)
.map_err(axum::Error::new);
Some(bytes)
})
},
)
});

let append_stream: BoxStream<Result<axum::body::Bytes, axum::Error>> =
Box::pin(futures::stream::once(futures::future::ready({
write_continuation().map_err(axum::Error::new)
})));

Box::pin(batch_stream.chain(append_stream))
}

fn http_response_trailers(&self) -> Option<HeaderMap> {
Expand Down Expand Up @@ -160,15 +220,13 @@ mod tests {

let client = TestClient::new(app).await;

let expected_buf: Vec<u8> = create_test_batch(schema.clone())
.iter()
.flat_map(|batch| {
let mut writer = StreamWriter::try_new(Vec::new(), &schema).expect("writer failed");
writer.write(&batch).expect("write failed");
writer.finish().expect("writer failed");
writer.into_inner().expect("writer failed")
})
.collect();
let mut writer =
arrow::ipc::writer::StreamWriter::try_new(Vec::new(), &schema).expect("writer failed");
for batch in create_test_batch(schema.clone()) {
writer.write(&batch).expect("write failed");
}
writer.finish().expect("writer failed");
let expected_buf = writer.into_inner().expect("writer failed");

let res = client.get("/").send().await.unwrap();
assert_eq!(
Expand All @@ -179,6 +237,7 @@ mod tests {
);
let body = res.bytes().await.unwrap().to_vec();

assert_eq!(body.len(), expected_buf.len());
assert_eq!(body, expected_buf);
}
}

0 comments on commit f5568c9

Please sign in to comment.