From f5568c9111b6caf3ce485331bb6e0753fbfb4253 Mon Sep 17 00:00:00 2001 From: Abdulla Abdurakhmanov Date: Sat, 13 Apr 2024 15:17:21 +0200 Subject: [PATCH] Fix arrow stream serialization (#41) --- src/arrow_format.rs | 113 +++++++++++++++++++++++++++++++++----------- 1 file changed, 86 insertions(+), 27 deletions(-) diff --git a/src/arrow_format.rs b/src/arrow_format.rs index 3cd5ba3..00a92a8 100644 --- a/src/arrow_format.rs +++ b/src/arrow_format.rs @@ -2,12 +2,14 @@ use crate::stream_body_as::StreamBodyAsOptions; use crate::{StreamBodyAs, StreamingFormat}; use arrow::array::RecordBatch; use arrow::datatypes::{Schema, SchemaRef}; -use arrow::ipc::writer::{IpcWriteOptions, StreamWriter}; +use arrow::error::ArrowError; +use arrow::ipc::writer::{write_message, DictionaryTracker, IpcDataGenerator, IpcWriteOptions}; use bytes::{BufMut, BytesMut}; use futures::stream::BoxStream; use futures::Stream; use futures::StreamExt; use http::HeaderMap; +use std::io::Write; use std::sync::Arc; pub struct ArrowRecordBatchIpcStreamFormat { @@ -33,23 +35,81 @@ impl StreamingFormat for ArrowRecordBatchIpcStreamFormat { &'a self, stream: BoxStream<'b, RecordBatch>, ) -> BoxStream<'b, Result> { - let schema = self.schema.clone(); - let options = self.options.clone(); - - Box::pin({ - stream.map(move |batch| { - let buf = BytesMut::new().writer(); - let mut writer = StreamWriter::try_new_with_options(buf, &schema, options.clone()) - .map_err(axum::Error::new)?; - writer.write(&batch).map_err(axum::Error::new)?; - writer.finish().map_err(axum::Error::new)?; - writer - .into_inner() - .map_err(axum::Error::new) - .map(|buf| buf.into_inner().freeze()) - .map(axum::body::Bytes::from) - }) - }) + fn write_batch( + ipc_data_gen: &mut IpcDataGenerator, + dictionary_tracker: &mut DictionaryTracker, + write_options: &IpcWriteOptions, + batch: &RecordBatch, + prepend_schema: Option>, + ) -> Result { + let mut writer = BytesMut::new().writer(); + + if let Some(prepend_schema) = prepend_schema { + let encoded_message = ipc_data_gen.schema_to_bytes(&prepend_schema, write_options); + write_message(&mut writer, encoded_message, write_options)?; + } + + let (encoded_dictionaries, encoded_message) = ipc_data_gen + .encoded_batch(batch, dictionary_tracker, write_options) + .expect("StreamWriter is configured to not error on dictionary replacement"); + + for encoded_dictionary in encoded_dictionaries { + write_message(&mut writer, encoded_dictionary, write_options)?; + } + + write_message(&mut writer, encoded_message, write_options)?; + writer.flush()?; + Ok(writer.into_inner().freeze()) + } + + fn write_continuation() -> Result { + let mut writer = BytesMut::new().writer(); + const CONTINUATION_MARKER: [u8; 4] = [0xff; 4]; + let total_len = 0_i32.to_le_bytes(); // Always zero in the stream format + + writer.write_all(&CONTINUATION_MARKER)?; + writer.write_all(&total_len[..])?; + writer.flush()?; + Ok(writer.into_inner().freeze()) + } + + let batch_schema = self.schema.clone(); + let batch_options = self.options.clone(); + + let ipc_data_gen = IpcDataGenerator::default(); + let dictionary_tracker: DictionaryTracker = DictionaryTracker::new(false); + + let batch_stream = Box::pin({ + stream.scan( + (ipc_data_gen, dictionary_tracker, 0), + move |(ipc_data_gen, dictionary_tracker, idx), batch| { + futures::future::ready({ + let prepend_schema = if *idx == 0 { + Some(batch_schema.clone()) + } else { + None + }; + *idx += 1; + let bytes = write_batch( + ipc_data_gen, + dictionary_tracker, + &batch_options, + &batch, + prepend_schema, + ) + .map_err(axum::Error::new); + Some(bytes) + }) + }, + ) + }); + + let append_stream: BoxStream> = + Box::pin(futures::stream::once(futures::future::ready({ + write_continuation().map_err(axum::Error::new) + }))); + + Box::pin(batch_stream.chain(append_stream)) } fn http_response_trailers(&self) -> Option { @@ -160,15 +220,13 @@ mod tests { let client = TestClient::new(app).await; - let expected_buf: Vec = create_test_batch(schema.clone()) - .iter() - .flat_map(|batch| { - let mut writer = StreamWriter::try_new(Vec::new(), &schema).expect("writer failed"); - writer.write(&batch).expect("write failed"); - writer.finish().expect("writer failed"); - writer.into_inner().expect("writer failed") - }) - .collect(); + let mut writer = + arrow::ipc::writer::StreamWriter::try_new(Vec::new(), &schema).expect("writer failed"); + for batch in create_test_batch(schema.clone()) { + writer.write(&batch).expect("write failed"); + } + writer.finish().expect("writer failed"); + let expected_buf = writer.into_inner().expect("writer failed"); let res = client.get("/").send().await.unwrap(); assert_eq!( @@ -179,6 +237,7 @@ mod tests { ); let body = res.bytes().await.unwrap().to_vec(); + assert_eq!(body.len(), expected_buf.len()); assert_eq!(body, expected_buf); } }