diff --git a/src/daft-core/src/array/ops/concat.rs b/src/daft-core/src/array/ops/concat.rs index 36addc91fa..58906ae2a4 100644 --- a/src/daft-core/src/array/ops/concat.rs +++ b/src/daft-core/src/array/ops/concat.rs @@ -6,6 +6,61 @@ use common_error::{DaftError, DaftResult}; #[cfg(feature = "python")] use crate::array::pseudo_arrow::PseudoArrowArray; +macro_rules! impl_variable_length_concat { + ($fn_name:ident, $arrow_type:ty, $create_fn: ident) => { + fn $fn_name(arrays: &[&dyn arrow2::array::Array]) -> DaftResult> { + let mut num_rows: usize = 0; + let mut num_bytes: usize = 0; + let mut need_validity = false; + for arr in arrays { + let arr = arr.as_any().downcast_ref::<$arrow_type>().unwrap(); + + num_rows += arr.len(); + num_bytes += arr.values().len(); + need_validity |= arr.validity().map(|v| v.unset_bits() > 0).unwrap_or(false); + } + let mut offsets = arrow2::offset::Offsets::::with_capacity(num_rows); + + let mut validity = if need_validity { + Some(arrow2::bitmap::MutableBitmap::with_capacity(num_rows)) + } else { + None + }; + let mut buffer = Vec::::with_capacity(num_bytes); + + for arr in arrays { + let arr = arr.as_any().downcast_ref::<$arrow_type>().unwrap(); + offsets.try_extend_from_slice(arr.offsets(), 0, arr.len())?; + if let Some(ref mut bitmap) = validity { + if let Some(b) = arr.validity() { + bitmap.extend_from_bitmap(b); + } else { + bitmap.extend_constant(arr.len(), true); + } + } + buffer.extend_from_slice(arr.values().as_slice()); + } + let dtype = arrays.first().unwrap().data_type().clone(); + #[allow(unused_unsafe)] + let result_array = unsafe { + <$arrow_type>::$create_fn( + dtype, + offsets.into(), + buffer.into(), + validity.map(|v| v.into()), + ) + }?; + Ok(Box::new(result_array)) + } + }; +} +impl_variable_length_concat!( + utf8_concat, + arrow2::array::Utf8Array, + try_new_unchecked +); +impl_variable_length_concat!(binary_concat, arrow2::array::BinaryArray, try_new); + impl DataArray where T: DaftPhysicalType, @@ -41,6 +96,14 @@ where )); DataArray::new(field.clone(), cat_array) } + crate::DataType::Utf8 => { + let cat_array = utf8_concat(arrow_arrays.as_slice())?; + DataArray::new(field.clone(), cat_array) + } + crate::DataType::Binary => { + let cat_array = binary_concat(arrow_arrays.as_slice())?; + DataArray::new(field.clone(), cat_array) + } _ => { let cat_array: Box = arrow2::compute::concatenate::concatenate(arrow_arrays.as_slice())?; diff --git a/tests/benchmarks/test_concat.py b/tests/benchmarks/test_concat.py new file mode 100644 index 0000000000..f11b9dd5ef --- /dev/null +++ b/tests/benchmarks/test_concat.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +import uuid + +from daft.series import Series + + +def test_string_concat(benchmark) -> None: + NUM_ROWS = 100_000 + data = Series.from_pylist([str(uuid.uuid4()) for _ in range(NUM_ROWS)]) + to_concat = [data] * 100 + + def bench_concat() -> Series: + return Series.concat(to_concat) + + benchmark(bench_concat)