From 2a8d6aaf86040694547ecba6363d9f521c551fe7 Mon Sep 17 00:00:00 2001 From: clarkzinzow Date: Fri, 22 Sep 2023 12:34:40 -0700 Subject: [PATCH] Add .str.split(). --- daft/daft.pyi | 2 + daft/expressions/expressions.py | 18 +++- daft/series.py | 6 ++ src/daft-core/src/array/ops/full.rs | 3 +- src/daft-core/src/array/ops/utf8.rs | 123 ++++++++++++++++++++++- src/daft-core/src/python/series.rs | 4 + src/daft-core/src/series/ops/utf8.rs | 9 ++ src/daft-dsl/src/functions/utf8/mod.rs | 11 ++ src/daft-dsl/src/functions/utf8/split.rs | 50 +++++++++ src/daft-dsl/src/python.rs | 5 + tests/expressions/typing/test_str.py | 1 + tests/series/test_utf8_ops.py | 98 ++++++++++++++++++ tests/table/utf8/test_split.py | 28 ++++++ 13 files changed, 355 insertions(+), 3 deletions(-) create mode 100644 src/daft-dsl/src/functions/utf8/split.rs create mode 100644 tests/table/utf8/test_split.py diff --git a/daft/daft.pyi b/daft/daft.pyi index c4d9eee604..1a6ee4f3e3 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -555,6 +555,7 @@ class PyExpr: def utf8_endswith(self, pattern: PyExpr) -> PyExpr: ... def utf8_startswith(self, pattern: PyExpr) -> PyExpr: ... def utf8_contains(self, pattern: PyExpr) -> PyExpr: ... + def utf8_split(self, pattern: PyExpr) -> PyExpr: ... def utf8_length(self) -> PyExpr: ... def image_decode(self) -> PyExpr: ... def image_encode(self, image_format: ImageFormat) -> PyExpr: ... @@ -617,6 +618,7 @@ class PySeries: def utf8_endswith(self, pattern: PySeries) -> PySeries: ... def utf8_startswith(self, pattern: PySeries) -> PySeries: ... def utf8_contains(self, pattern: PySeries) -> PySeries: ... + def utf8_split(self, pattern: PySeries) -> PySeries: ... def utf8_length(self) -> PySeries: ... def is_nan(self) -> PySeries: ... def dt_date(self) -> PySeries: ... 
diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 2eaf5eb290..d73f8b4709 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -572,7 +572,7 @@ def endswith(self, suffix: str | Expression) -> Expression: suffix_expr = Expression._to_expression(suffix) return Expression._from_pyexpr(self._expr.utf8_endswith(suffix_expr._expr)) - def startswith(self, prefix: str) -> Expression: + def startswith(self, prefix: str | Expression) -> Expression: """Checks whether each string starts with the given pattern in a string column Example: @@ -587,6 +587,22 @@ def startswith(self, prefix: str) -> Expression: prefix_expr = Expression._to_expression(prefix) return Expression._from_pyexpr(self._expr.utf8_startswith(prefix_expr._expr)) + def split(self, pattern: str | Expression) -> Expression: + """Splits each string on the given pattern, into one or more strings. + + Example: + >>> col("x").str.split(",") + >>> col("x").str.split(col("pattern")) + + Args: + pattern: The pattern on which each string should be split, or a column to pick such patterns from. + + Returns: + Expression: A List[Utf8] expression containing the string splits for each string in the column. 
+ """ + pattern_expr = Expression._to_expression(pattern) + return Expression._from_pyexpr(self._expr.utf8_split(pattern_expr._expr)) + def concat(self, other: str) -> Expression: """Concatenates two string expressions together diff --git a/daft/series.py b/daft/series.py index d81196c391..2430dbf31d 100644 --- a/daft/series.py +++ b/daft/series.py @@ -534,6 +534,12 @@ def contains(self, pattern: Series) -> Series: assert self._series is not None and pattern._series is not None return Series._from_pyseries(self._series.utf8_contains(pattern._series)) + def split(self, pattern: Series) -> Series: + if not isinstance(pattern, Series): + raise ValueError(f"expected another Series but got {type(pattern)}") + assert self._series is not None and pattern._series is not None + return Series._from_pyseries(self._series.utf8_split(pattern._series)) + def concat(self, other: Series) -> Series: if not isinstance(other, Series): raise ValueError(f"expected another Series but got {type(other)}") diff --git a/src/daft-core/src/array/ops/full.rs b/src/daft-core/src/array/ops/full.rs index f39c904709..f2f86c5298 100644 --- a/src/daft-core/src/array/ops/full.rs +++ b/src/daft-core/src/array/ops/full.rs @@ -129,7 +129,8 @@ impl FullNull for ListArray { Self::new( Field::new(name, dtype.clone()), empty_flat_child, - OffsetsBuffer::try_from(repeat(0).take(length).collect::>()).unwrap(), + OffsetsBuffer::try_from(repeat(0).take(length + 1).collect::>()) + .unwrap(), Some(validity), ) } diff --git a/src/daft-core/src/array/ops/utf8.rs b/src/daft-core/src/array/ops/utf8.rs index 262559776f..56e91ec7d9 100644 --- a/src/daft-core/src/array/ops/utf8.rs +++ b/src/daft-core/src/array/ops/utf8.rs @@ -1,10 +1,72 @@ -use crate::datatypes::{BooleanArray, UInt64Array, Utf8Array}; +use crate::{ + array::ListArray, + datatypes::{BooleanArray, Field, UInt64Array, Utf8Array}, + DataType, Series, +}; use arrow2; use common_error::{DaftError, DaftResult}; use super::{as_arrow::AsArrow, full::FullNull}; 
+fn split_array_on_patterns<'a, T, U>( + arr_iter: T, + pattern_iter: U, + buffer_len: usize, + name: &str, +) -> DaftResult +where + T: arrow2::trusted_len::TrustedLen + Iterator>, + U: Iterator>, +{ + // This will overallocate by pattern_len * N_i, where N_i is the number of pattern occurrences in the ith string in arr_iter. + let mut splits = arrow2::array::MutableUtf8Array::with_capacity(buffer_len); + // arr_iter implementing TrustedLen guarantees that the size_hint reports an accurate length. Specifically, we have that either + // (1) size_hint().0 == size_hint().1 == iterator length in the common case, or; + // (2) size_hint().0 == usize::MAX and size_hint().1 == None if the iterator is larger than usize::MAX. + // + // Since the iterator is guaranteed to be smaller than usize::MAX due to the UTF8Array i64 offset array constraint (no more than i64::MAX, + // which we assume to be smaller than usize::MAX), we should have that (1) always holds, so we can reliably unwrap the size hint upper bound + // and treat it as the iterator length. + let arr_len = arr_iter.size_hint().1.unwrap(); + let mut offsets = Vec::with_capacity(arr_len + 1); + offsets.push(0i64); + let mut validity = arrow2::bitmap::MutableBitmap::with_capacity(arr_len); + for (val, pat) in arr_iter.zip(pattern_iter) { + let mut num_splits = 0i64; + match (val, pat) { + (Some(val), Some(pat)) => { + for split in val.split(pat) { + splits.push(Some(split)); + num_splits += 1; + } + validity.push(true); + } + (_, _) => { + validity.push(false); + } + } + let offset = offsets.last().unwrap() + num_splits; + offsets.push(offset); + } + // Shrink splits capacity to current length, since we will have overallocated if any of the patterns actually occurred in the strings. 
+ splits.shrink_to_fit(); + let splits: arrow2::array::Utf8Array = splits.into(); + let offsets = arrow2::offset::OffsetsBuffer::try_from(offsets)?; + let validity: Option = match validity.unset_bits() { + 0 => None, + _ => Some(validity.into()), + }; + let flat_child = + Series::try_from(("splits", Box::new(splits) as Box))?; + Ok(ListArray::new( + Field::new(name, DataType::List(Box::new(DataType::Utf8))), + flat_child, + offsets, + validity, + )) +} + impl Utf8Array { pub fn endswith(&self, pattern: &Utf8Array) -> DaftResult { self.binary_broadcasted_compare(pattern, |data: &str, pat: &str| data.ends_with(pat)) @@ -18,6 +80,65 @@ impl Utf8Array { self.binary_broadcasted_compare(pattern, |data: &str, pat: &str| data.contains(pat)) } + pub fn split(&self, pattern: &Utf8Array) -> DaftResult { + let self_arrow = self.as_arrow(); + let pattern_arrow = pattern.as_arrow(); + // Handle all-null cases. + if self_arrow + .validity() + .map_or(false, |v| v.unset_bits() == v.len()) + || pattern_arrow + .validity() + .map_or(false, |v| v.unset_bits() == v.len()) + { + return Ok(ListArray::full_null( + self.name(), + &DataType::List(Box::new(DataType::Utf8)), + std::cmp::max(self.len(), pattern.len()), + )); + // Handle empty cases. 
+ } else if self.is_empty() || pattern.is_empty() { + return Ok(ListArray::empty( + self.name(), + &DataType::List(Box::new(DataType::Utf8)), + )); + } + let buffer_len = self_arrow.values().len(); + match (self.len(), pattern.len()) { + // Matching len case: + (self_len, pattern_len) if self_len == pattern_len => split_array_on_patterns( + self_arrow.into_iter(), + pattern_arrow.into_iter(), + buffer_len, + self.name(), + ), + // Broadcast pattern case: + (self_len, 1) => { + let pattern_scalar_value = pattern.get(0).unwrap(); + split_array_on_patterns( + self_arrow.into_iter(), + std::iter::repeat(Some(pattern_scalar_value)).take(self_len), + buffer_len, + self.name(), + ) + } + // Broadcast self case: + (1, pattern_len) => { + let self_scalar_value = self.get(0).unwrap(); + split_array_on_patterns( + std::iter::repeat(Some(self_scalar_value)).take(pattern_len), + pattern_arrow.into_iter(), + buffer_len * pattern_len, + self.name(), + ) + } + // Mismatched len case: + (self_len, pattern_len) => Err(DaftError::ComputeError(format!( + "lhs and rhs have different length arrays: {self_len} vs {pattern_len}" + ))), + } + } + pub fn length(&self) -> DaftResult { let self_arrow = self.as_arrow(); let arrow_result = self_arrow diff --git a/src/daft-core/src/python/series.rs b/src/daft-core/src/python/series.rs index a388585940..f3c8cb3fa0 100644 --- a/src/daft-core/src/python/series.rs +++ b/src/daft-core/src/python/series.rs @@ -247,6 +247,10 @@ impl PySeries { Ok(self.series.utf8_contains(&pattern.series)?.into()) } + pub fn utf8_split(&self, pattern: &Self) -> PyResult { + Ok(self.series.utf8_split(&pattern.series)?.into()) + } + pub fn utf8_length(&self) -> PyResult { Ok(self.series.utf8_length()?.into()) } diff --git a/src/daft-core/src/series/ops/utf8.rs b/src/daft-core/src/series/ops/utf8.rs index bea305d9fc..fb2539b64e 100644 --- a/src/daft-core/src/series/ops/utf8.rs +++ b/src/daft-core/src/series/ops/utf8.rs @@ -32,6 +32,15 @@ impl Series { } } + pub fn 
utf8_split(&self, pattern: &Series) -> DaftResult { + match self.data_type() { + DataType::Utf8 => Ok(self.utf8()?.split(pattern.utf8()?)?.into_series()), + dt => Err(DaftError::TypeError(format!( + "Split not implemented for type {dt}" + ))), + } + } + pub fn utf8_length(&self) -> DaftResult { match self.data_type() { DataType::Utf8 => Ok(self.utf8()?.length()?.into_series()), diff --git a/src/daft-dsl/src/functions/utf8/mod.rs b/src/daft-dsl/src/functions/utf8/mod.rs index cd23c0883b..5c8901147e 100644 --- a/src/daft-dsl/src/functions/utf8/mod.rs +++ b/src/daft-dsl/src/functions/utf8/mod.rs @@ -1,12 +1,14 @@ mod contains; mod endswith; mod length; +mod split; mod startswith; use contains::ContainsEvaluator; use endswith::EndswithEvaluator; use length::LengthEvaluator; use serde::{Deserialize, Serialize}; +use split::SplitEvaluator; use startswith::StartswithEvaluator; use crate::Expr; @@ -18,6 +20,7 @@ pub enum Utf8Expr { EndsWith, StartsWith, Contains, + Split, Length, } @@ -29,6 +32,7 @@ impl Utf8Expr { EndsWith => &EndswithEvaluator {}, StartsWith => &StartswithEvaluator {}, Contains => &ContainsEvaluator {}, + Split => &SplitEvaluator {}, Length => &LengthEvaluator {}, } } @@ -55,6 +59,13 @@ pub fn contains(data: &Expr, pattern: &Expr) -> Expr { } } +pub fn split(data: &Expr, pattern: &Expr) -> Expr { + Expr::Function { + func: super::FunctionExpr::Utf8(Utf8Expr::Split), + inputs: vec![data.clone(), pattern.clone()], + } +} + pub fn length(data: &Expr) -> Expr { Expr::Function { func: super::FunctionExpr::Utf8(Utf8Expr::Length), diff --git a/src/daft-dsl/src/functions/utf8/split.rs b/src/daft-dsl/src/functions/utf8/split.rs new file mode 100644 index 0000000000..8d2c238b70 --- /dev/null +++ b/src/daft-dsl/src/functions/utf8/split.rs @@ -0,0 +1,50 @@ +use crate::Expr; +use daft_core::{ + datatypes::{DataType, Field}, + schema::Schema, + series::Series, +}; + +use common_error::{DaftError, DaftResult}; + +use super::super::FunctionEvaluator; + +pub(super) 
struct SplitEvaluator {} + +impl FunctionEvaluator for SplitEvaluator { + fn fn_name(&self) -> &'static str { + "split" + } + + fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult { + match inputs { + [data, pattern] => match (data.to_field(schema), pattern.to_field(schema)) { + (Ok(data_field), Ok(pattern_field)) => { + match (&data_field.dtype, &pattern_field.dtype) { + (DataType::Utf8, DataType::Utf8) => { + Ok(Field::new(data_field.name, DataType::List(Box::new(DataType::Utf8)))) + } + _ => Err(DaftError::TypeError(format!( + "Expects inputs to split to be utf8, but received {data_field} and {pattern_field}", + ))), + } + } + (Err(e), _) | (_, Err(e)) => Err(e), + }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 2 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series], _: &Expr) -> DaftResult { + match inputs { + [data, pattern] => data.utf8_split(pattern), + _ => Err(DaftError::ValueError(format!( + "Expected 2 input args, got {}", + inputs.len() + ))), + } + } +} diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index c64afc4271..cb61044339 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -291,6 +291,11 @@ impl PyExpr { Ok(contains(&self.expr, &pattern.expr).into()) } + pub fn utf8_split(&self, pattern: &Self) -> PyResult { + use crate::functions::utf8::split; + Ok(split(&self.expr, &pattern.expr).into()) + } + pub fn utf8_length(&self) -> PyResult { use crate::functions::utf8::length; Ok(length(&self.expr).into()) diff --git a/tests/expressions/typing/test_str.py b/tests/expressions/typing/test_str.py index 99d85ac460..99ab473190 100644 --- a/tests/expressions/typing/test_str.py +++ b/tests/expressions/typing/test_str.py @@ -15,6 +15,7 @@ pytest.param(lambda data, pat: data.str.contains(pat), id="contains"), pytest.param(lambda data, pat: data.str.startswith(pat), id="startswith"), pytest.param(lambda data, pat: data.str.endswith(pat), 
id="endswith"), + pytest.param(lambda data, pat: data.str.split(pat), id="split"), pytest.param(lambda data, pat: data.str.concat(pat), id="concat"), ], ) diff --git a/tests/series/test_utf8_ops.py b/tests/series/test_utf8_ops.py index dafd294423..1019769501 100644 --- a/tests/series/test_utf8_ops.py +++ b/tests/series/test_utf8_ops.py @@ -102,6 +102,104 @@ def test_series_utf8_compare_invalid_inputs(funcname, bad_series) -> None: getattr(s.str, funcname)(bad_series) +@pytest.mark.parametrize( + ["data", "patterns", "expected"], + [ + # Single-character pattern. + (["a,b,c", "d,e", "f", "g,h"], [","], [["a", "b", "c"], ["d", "e"], ["f"], ["g", "h"]]), + # Multi-character pattern. + (["abbcbbd", "bb", "bbe", "fbb"], ["bb"], [["a", "c", "d"], ["", ""], ["", "e"], ["f", ""]]), + # Empty pattern (character-splitting). + (["foo", "bar"], [""], [["", "f", "o", "o", ""], ["", "b", "a", "r", ""]]), + ], +) +def test_series_utf8_split_broadcast_pattern(data, patterns, expected) -> None: + s = Series.from_arrow(pa.array(data)) + patterns = Series.from_arrow(pa.array(patterns)) + result = s.str.split(patterns) + assert result.to_pylist() == expected + + +@pytest.mark.parametrize( + ["data", "patterns", "expected"], + [ + (["a,b,c", "a:b:c", "a;b;c", "a.b.c"], [",", ":", ";", "."], [["a", "b", "c"]] * 4), + (["aabbccdd"] * 4, ["aa", "bb", "cc", "dd"], [["", "bbccdd"], ["aa", "ccdd"], ["aabb", "dd"], ["aabbcc", ""]]), + ], +) +def test_series_utf8_split_multi_pattern(data, patterns, expected) -> None: + s = Series.from_arrow(pa.array(data)) + patterns = Series.from_arrow(pa.array(patterns)) + result = s.str.split(patterns) + assert result.to_pylist() == expected + + +@pytest.mark.parametrize( + ["data", "patterns", "expected"], + [ + (["aabbccdd"], ["aa", "bb", "cc", "dd"], [["", "bbccdd"], ["aa", "ccdd"], ["aabb", "dd"], ["aabbcc", ""]]), + ], +) +def test_series_utf8_split_broadcast_arr(data, patterns, expected) -> None: + s = Series.from_arrow(pa.array(data)) + patterns = 
Series.from_arrow(pa.array(patterns)) + result = s.str.split(patterns) + assert result.to_pylist() == expected + + +@pytest.mark.parametrize( + ["data", "patterns", "expected"], + [ + # Mixed-in nulls. + (["a,b,c", None, "a;b;c", "a.b.c"], [",", ":", None, "."], [["a", "b", "c"], None, None, ["a", "b", "c"]]), + # All null data. + ([None] * 4, [","] * 4, [None] * 4), + # All null patterns. + (["foo"] * 4, [None] * 4, [None] * 4), + # Broadcasted null data. + ([None], [","] * 4, [None] * 4), + # Broadcasted null pattern. + (["foo"] * 4, [None], [None] * 4), + ], +) +def test_series_utf8_split_nulls(data, patterns, expected) -> None: + s = Series.from_arrow(pa.array(data, type=pa.string())) + patterns = Series.from_arrow(pa.array(patterns, type=pa.string())) + result = s.str.split(patterns) + assert result.to_pylist() == expected + + +@pytest.mark.parametrize( + ["data", "patterns", "expected"], + [ + # Empty data. + ([[], [","] * 4, []]), + # Empty patterns. + ([["foo"] * 4, [], []]), + ], +) +def test_series_utf8_split_empty_arrs(data, patterns, expected) -> None: + s = Series.from_arrow(pa.array(data, type=pa.string())) + patterns = Series.from_arrow(pa.array(patterns, type=pa.string())) + result = s.str.split(patterns) + assert result.to_pylist() == expected + + +@pytest.mark.parametrize( + "patterns", + [ + # Wrong number of elements, not broadcastable + Series.from_arrow(pa.array([",", "."], type=pa.string())), + # Bad input type + object(), + ], +) +def test_series_utf8_split_invalid_inputs(patterns) -> None: + s = Series.from_arrow(pa.array(["a,b,c", "d, e", "f"])) + with pytest.raises(ValueError): + s.str.split(patterns) + + def test_series_utf8_length() -> None: s = Series.from_arrow(pa.array(["foo", "barbaz", "quux"])) result = s.str.length() diff --git a/tests/table/utf8/test_split.py b/tests/table/utf8/test_split.py new file mode 100644 index 0000000000..0da7735b1d --- /dev/null +++ b/tests/table/utf8/test_split.py @@ -0,0 +1,28 @@ +from __future__ 
import annotations + +import pytest + +from daft.expressions import col, lit +from daft.table import Table + + +@pytest.mark.parametrize( + ["expr", "data", "expected"], + [ + (col("col").str.split(","), ["a,b,c", "d,e", "f", "g,h"], [["a", "b", "c"], ["d", "e"], ["f"], ["g", "h"]]), + ( + col("col").str.split(lit(",")), + ["a,b,c", "d,e", "f", "g,h"], + [["a", "b", "c"], ["d", "e"], ["f"], ["g", "h"]], + ), + ( + col("col").str.split(col("emptystrings") + lit(",")), + ["a,b,c", "d,e", "f", "g,h"], + [["a", "b", "c"], ["d", "e"], ["f"], ["g", "h"]], + ), + ], +) +def test_series_utf8_split_broadcast_pattern(expr, data, expected) -> None: + table = Table.from_pydict({"col": data, "emptystrings": ["", "", "", ""]}) + result = table.eval_expression_list([expr]) + assert result.to_pydict() == {"col": expected}