diff --git a/src/daft-core/src/array/ops/concat_agg.rs b/src/daft-core/src/array/ops/concat_agg.rs index 6713597897..09ebb0876e 100644 --- a/src/daft-core/src/array/ops/concat_agg.rs +++ b/src/daft-core/src/array/ops/concat_agg.rs @@ -1,10 +1,18 @@ -use arrow2::{bitmap::utils::SlicesIterator, offset::OffsetsBuffer, types::Index}; +use arrow2::{ + array::{Array, Utf8Array}, + bitmap::utils::SlicesIterator, + offset::OffsetsBuffer, + types::Index, +}; use common_error::DaftResult; use super::{as_arrow::AsArrow, DaftConcatAggable}; -use crate::array::{ - growable::{make_growable, Growable}, - ListArray, +use crate::{ + array::{ + growable::{make_growable, Growable}, + DataArray, ListArray, + }, + prelude::Utf8Type, }; #[cfg(feature = "python")] @@ -146,6 +154,67 @@ impl DaftConcatAggable for ListArray { } } +impl DaftConcatAggable for DataArray { + type Output = DaftResult; + + fn concat(&self) -> Self::Output { + let new_validity = match self.validity() { + Some(validity) if validity.unset_bits() == self.len() => { + Some(arrow2::bitmap::Bitmap::from(vec![false])) + } + _ => None, + }; + + let arrow_array = self.as_arrow(); + let new_offsets = OffsetsBuffer::::try_from(vec![0, *arrow_array.offsets().last()])?; + let output = Utf8Array::new( + arrow_array.data_type().clone(), + new_offsets, + arrow_array.values().clone(), + new_validity, + ); + + let result_box = Box::new(output); + DataArray::new(self.field().clone().into(), result_box) + } + + fn grouped_concat(&self, groups: &super::GroupIndices) -> Self::Output { + let arrow_array = self.as_arrow(); + let concat_per_group = if arrow_array.null_count() > 0 { + Box::new(Utf8Array::from_trusted_len_iter(groups.iter().map(|g| { + let to_concat = g + .iter() + .filter_map(|index| { + let idx = *index as usize; + arrow_array.get(idx) + }) + .collect::>(); + if to_concat.is_empty() { + None + } else { + Some(to_concat.concat()) + } + }))) + } else { + Box::new(Utf8Array::from_trusted_len_values_iter(groups.iter().map( + |g| { + g.iter() + .map(|index| { + let idx = *index as usize; + arrow_array.value(idx) + }) + .collect::() + }, + ))) + }; + + Ok(DataArray::from(( + self.field.name.as_ref(), + concat_per_group, + ))) + } +} + #[cfg(test)] mod test { use std::iter::repeat; diff --git a/src/daft-core/src/series/ops/agg.rs b/src/daft-core/src/series/ops/agg.rs index 44a4c10348..353c6ca25d 100644 --- a/src/daft-core/src/series/ops/agg.rs +++ b/src/daft-core/src/series/ops/agg.rs @@ -244,8 +244,17 @@ impl Series { None => Ok(DaftConcatAggable::concat(downcasted)?.into_series()), } } + DataType::Utf8 => { + let downcasted = self.downcast::()?; + match groups { + Some(groups) => { + Ok(DaftConcatAggable::grouped_concat(downcasted, groups)?.into_series()) + } + None => Ok(DaftConcatAggable::concat(downcasted)?.into_series()), + } + } _ => Err(DaftError::TypeError(format!( - "concat aggregation is only valid for List or Python types, got {}", + "concat aggregation is only valid for List, Python types, or Utf8, got {}", self.data_type() ))), } diff --git a/src/daft-dsl/src/expr.rs b/src/daft-dsl/src/expr.rs index affb5f08e3..f8c5deb247 100644 --- a/src/daft-dsl/src/expr.rs +++ b/src/daft-dsl/src/expr.rs @@ -390,6 +390,7 @@ impl AggExpr { let field = expr.to_field(schema)?; match field.dtype { DataType::List(..) => Ok(field), + DataType::Utf8 => Ok(field), #[cfg(feature = "python")] DataType::Python => Ok(field), _ => Err(DaftError::TypeError(format!( diff --git a/tests/table/test_table_aggs.py b/tests/table/test_table_aggs.py index fa7a26b3e4..01749a1cdb 100644 --- a/tests/table/test_table_aggs.py +++ b/tests/table/test_table_aggs.py @@ -874,3 +874,53 @@ def test_groupby_struct(dtype) -> None: expected = [[0, 1, 4], [2, 6], [3, 5]] for lt in expected: assert lt in res["b"] + + +def test_agg_concat_on_string() -> None: + df3 = from_pydict({"a": ["the", " quick", " brown", " fox"]}) + res = df3.agg(col("a").agg_concat()).to_pydict() + assert res["a"] == ["the quick brown fox"] + + +def test_agg_concat_on_string_groupby() -> None: + df3 = from_pydict({"a": ["the", " quick", " brown", " fox"], "b": [1, 2, 1, 2]}) + res = df3.groupby("b").agg_concat("a").to_pydict() + expected = ["the brown", " quick fox"] + for txt in expected: + assert txt in res["a"] + + +def test_agg_concat_on_string_null() -> None: + df3 = from_pydict({"a": ["the", " quick", None, " fox"]}) + res = df3.agg(col("a").agg_concat()).to_pydict() + expected = ["the quick fox"] + assert res["a"] == expected + + +def test_agg_concat_on_string_groupby_null() -> None: + df3 = from_pydict({"a": ["the", " quick", None, " fox"], "b": [1, 2, 1, 2]}) + res = df3.groupby("b").agg_concat("a").to_pydict() + expected = ["the", " quick fox"] + for txt in expected: + assert txt in res["a"] + + +def test_agg_concat_on_string_null_list() -> None: + df3 = from_pydict({"a": [None, None, None, None], "b": [1, 2, 1, 2]}).with_column( + "a", col("a").cast(DataType.string()) + ) + res = df3.agg(col("a").agg_concat()).to_pydict() + print(res) + expected = [None] + assert res["a"] == expected + assert len(res["a"]) == 1 + + +def test_agg_concat_on_string_groupby_null_list() -> None: + df3 = from_pydict({"a": [None, None, None, None], "b": [1, 2, 1, 2]}).with_column( + "a", col("a").cast(DataType.string()) + ) + res = df3.groupby("b").agg_concat("a").to_pydict() + expected = [None, None] + assert res["a"] == expected + assert len(res["a"]) == len(expected)