Skip to content

Commit

Permalink
[FEAT] Add .str.split() API for splitting string columns. (#1409)
Browse files Browse the repository at this point in the history
This PR adds an `Expression.str.split()` API for splitting strings in
string columns on a pattern.

Closes #1388
  • Loading branch information
clarkzinzow authored Sep 28, 2023
1 parent 32fd4f6 commit 069432d
Show file tree
Hide file tree
Showing 13 changed files with 349 additions and 3 deletions.
2 changes: 2 additions & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,7 @@ class PyExpr:
def utf8_endswith(self, pattern: PyExpr) -> PyExpr: ...
def utf8_startswith(self, pattern: PyExpr) -> PyExpr: ...
def utf8_contains(self, pattern: PyExpr) -> PyExpr: ...
def utf8_split(self, pattern: PyExpr) -> PyExpr: ...
def utf8_length(self) -> PyExpr: ...
def image_decode(self) -> PyExpr: ...
def image_encode(self, image_format: ImageFormat) -> PyExpr: ...
Expand Down Expand Up @@ -617,6 +618,7 @@ class PySeries:
def utf8_endswith(self, pattern: PySeries) -> PySeries: ...
def utf8_startswith(self, pattern: PySeries) -> PySeries: ...
def utf8_contains(self, pattern: PySeries) -> PySeries: ...
def utf8_split(self, pattern: PySeries) -> PySeries: ...
def utf8_length(self) -> PySeries: ...
def is_nan(self) -> PySeries: ...
def dt_date(self) -> PySeries: ...
Expand Down
18 changes: 17 additions & 1 deletion daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,7 @@ def endswith(self, suffix: str | Expression) -> Expression:
suffix_expr = Expression._to_expression(suffix)
return Expression._from_pyexpr(self._expr.utf8_endswith(suffix_expr._expr))

def startswith(self, prefix: str) -> Expression:
def startswith(self, prefix: str | Expression) -> Expression:
"""Checks whether each string starts with the given pattern in a string column
Example:
Expand All @@ -587,6 +587,22 @@ def startswith(self, prefix: str) -> Expression:
prefix_expr = Expression._to_expression(prefix)
return Expression._from_pyexpr(self._expr.utf8_startswith(prefix_expr._expr))

def split(self, pattern: str | Expression) -> Expression:
"""Splits each string on the given pattern, into one or more strings.
Example:
>>> col("x").str.split(",")
>>> col("x").str.split(col("pattern"))
Args:
pattern: The pattern on which each string should be split, or a column to pick such patterns from.
Returns:
Expression: A List[Utf8] expression containing the string splits for each string in the column.
"""
pattern_expr = Expression._to_expression(pattern)
return Expression._from_pyexpr(self._expr.utf8_split(pattern_expr._expr))

def concat(self, other: str) -> Expression:
"""Concatenates two string expressions together
Expand Down
6 changes: 6 additions & 0 deletions daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,12 @@ def contains(self, pattern: Series) -> Series:
assert self._series is not None and pattern._series is not None
return Series._from_pyseries(self._series.utf8_contains(pattern._series))

def split(self, pattern: Series) -> Series:
if not isinstance(pattern, Series):
raise ValueError(f"expected another Series but got {type(pattern)}")
assert self._series is not None and pattern._series is not None
return Series._from_pyseries(self._series.utf8_split(pattern._series))

def concat(self, other: Series) -> Series:
if not isinstance(other, Series):
raise ValueError(f"expected another Series but got {type(other)}")
Expand Down
3 changes: 2 additions & 1 deletion src/daft-core/src/array/ops/full.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ impl FullNull for ListArray {
Self::new(
Field::new(name, dtype.clone()),
empty_flat_child,
OffsetsBuffer::try_from(repeat(0).take(length).collect::<Vec<_>>()).unwrap(),
OffsetsBuffer::try_from(repeat(0).take(length + 1).collect::<Vec<_>>())
.unwrap(),
Some(validity),
)
}
Expand Down
117 changes: 116 additions & 1 deletion src/daft-core/src/array/ops/utf8.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,66 @@
use crate::datatypes::{BooleanArray, UInt64Array, Utf8Array};
use crate::{
array::ListArray,
datatypes::{BooleanArray, Field, UInt64Array, Utf8Array},
DataType, Series,
};
use arrow2;

use common_error::{DaftError, DaftResult};

use super::{as_arrow::AsArrow, full::FullNull};

fn split_array_on_patterns<'a, T, U>(
arr_iter: T,
pattern_iter: U,
buffer_len: usize,
name: &str,
) -> DaftResult<ListArray>
where
T: arrow2::trusted_len::TrustedLen + Iterator<Item = Option<&'a str>>,
U: Iterator<Item = Option<&'a str>>,
{
// This will overallocate by pattern_len * N_i, where N_i is the number of pattern occurences in the ith string in arr_iter.
let mut splits = arrow2::array::MutableUtf8Array::with_capacity(buffer_len);
// arr_iter implements TrustedLen, so we can always use size_hint().1 as the exact length of the iterator. The only
// time this would fail is if the length of the iterator exceeds usize::MAX, which should never happen for an i64
// offset array, since the array length can't exceed i64::MAX on 64-bit machines.
let arr_len = arr_iter.size_hint().1.unwrap();
let mut offsets = arrow2::offset::Offsets::new();
let mut validity = arrow2::bitmap::MutableBitmap::with_capacity(arr_len);
for (val, pat) in arr_iter.zip(pattern_iter) {
let mut num_splits = 0i64;
match (val, pat) {
(Some(val), Some(pat)) => {
for split in val.split(pat) {
splits.push(Some(split));
num_splits += 1;
}
validity.push(true);
}
(_, _) => {
validity.push(false);
}
}
offsets.try_push(num_splits)?;
}
// Shrink splits capacity to current length, since we will have overallocated if any of the patterns actually occurred in the strings.
splits.shrink_to_fit();
let splits: arrow2::array::Utf8Array<i64> = splits.into();
let offsets: arrow2::offset::OffsetsBuffer<i64> = offsets.into();
let validity: Option<arrow2::bitmap::Bitmap> = match validity.unset_bits() {
0 => None,
_ => Some(validity.into()),
};
let flat_child =
Series::try_from(("splits", Box::new(splits) as Box<dyn arrow2::array::Array>))?;
Ok(ListArray::new(
Field::new(name, DataType::List(Box::new(DataType::Utf8))),
flat_child,
offsets,
validity,
))
}

impl Utf8Array {
pub fn endswith(&self, pattern: &Utf8Array) -> DaftResult<BooleanArray> {
self.binary_broadcasted_compare(pattern, |data: &str, pat: &str| data.ends_with(pat))
Expand All @@ -18,6 +74,65 @@ impl Utf8Array {
self.binary_broadcasted_compare(pattern, |data: &str, pat: &str| data.contains(pat))
}

pub fn split(&self, pattern: &Utf8Array) -> DaftResult<ListArray> {
let self_arrow = self.as_arrow();
let pattern_arrow = pattern.as_arrow();
// Handle all-null cases.
if self_arrow
.validity()
.map_or(false, |v| v.unset_bits() == v.len())
|| pattern_arrow
.validity()
.map_or(false, |v| v.unset_bits() == v.len())
{
return Ok(ListArray::full_null(
self.name(),
&DataType::List(Box::new(DataType::Utf8)),
std::cmp::max(self.len(), pattern.len()),
));
// Handle empty cases.
} else if self.is_empty() || pattern.is_empty() {
return Ok(ListArray::empty(
self.name(),
&DataType::List(Box::new(DataType::Utf8)),
));
}
let buffer_len = self_arrow.values().len();
match (self.len(), pattern.len()) {
// Matching len case:
(self_len, pattern_len) if self_len == pattern_len => split_array_on_patterns(
self_arrow.into_iter(),
pattern_arrow.into_iter(),
buffer_len,
self.name(),
),
// Broadcast pattern case:
(self_len, 1) => {
let pattern_scalar_value = pattern.get(0).unwrap();
split_array_on_patterns(
self_arrow.into_iter(),
std::iter::repeat(Some(pattern_scalar_value)).take(self_len),
buffer_len,
self.name(),
)
}
// Broadcast self case:
(1, pattern_len) => {
let self_scalar_value = self.get(0).unwrap();
split_array_on_patterns(
std::iter::repeat(Some(self_scalar_value)).take(pattern_len),
pattern_arrow.into_iter(),
buffer_len * pattern_len,
self.name(),
)
}
// Mismatched len case:
(self_len, pattern_len) => Err(DaftError::ComputeError(format!(
"lhs and rhs have different length arrays: {self_len} vs {pattern_len}"
))),
}
}

pub fn length(&self) -> DaftResult<UInt64Array> {
let self_arrow = self.as_arrow();
let arrow_result = self_arrow
Expand Down
4 changes: 4 additions & 0 deletions src/daft-core/src/python/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,10 @@ impl PySeries {
Ok(self.series.utf8_contains(&pattern.series)?.into())
}

pub fn utf8_split(&self, pattern: &Self) -> PyResult<Self> {
Ok(self.series.utf8_split(&pattern.series)?.into())
}

pub fn utf8_length(&self) -> PyResult<Self> {
Ok(self.series.utf8_length()?.into())
}
Expand Down
9 changes: 9 additions & 0 deletions src/daft-core/src/series/ops/utf8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@ impl Series {
}
}

pub fn utf8_split(&self, pattern: &Series) -> DaftResult<Series> {
match self.data_type() {
DataType::Utf8 => Ok(self.utf8()?.split(pattern.utf8()?)?.into_series()),
dt => Err(DaftError::TypeError(format!(
"Split not implemented for type {dt}"
))),
}
}

pub fn utf8_length(&self) -> DaftResult<Series> {
match self.data_type() {
DataType::Utf8 => Ok(self.utf8()?.length()?.into_series()),
Expand Down
11 changes: 11 additions & 0 deletions src/daft-dsl/src/functions/utf8/mod.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
mod contains;
mod endswith;
mod length;
mod split;
mod startswith;

use contains::ContainsEvaluator;
use endswith::EndswithEvaluator;
use length::LengthEvaluator;
use serde::{Deserialize, Serialize};
use split::SplitEvaluator;
use startswith::StartswithEvaluator;

use crate::Expr;
Expand All @@ -18,6 +20,7 @@ pub enum Utf8Expr {
EndsWith,
StartsWith,
Contains,
Split,
Length,
}

Expand All @@ -29,6 +32,7 @@ impl Utf8Expr {
EndsWith => &EndswithEvaluator {},
StartsWith => &StartswithEvaluator {},
Contains => &ContainsEvaluator {},
Split => &SplitEvaluator {},
Length => &LengthEvaluator {},
}
}
Expand All @@ -55,6 +59,13 @@ pub fn contains(data: &Expr, pattern: &Expr) -> Expr {
}
}

pub fn split(data: &Expr, pattern: &Expr) -> Expr {
Expr::Function {
func: super::FunctionExpr::Utf8(Utf8Expr::Split),
inputs: vec![data.clone(), pattern.clone()],
}
}

pub fn length(data: &Expr) -> Expr {
Expr::Function {
func: super::FunctionExpr::Utf8(Utf8Expr::Length),
Expand Down
50 changes: 50 additions & 0 deletions src/daft-dsl/src/functions/utf8/split.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
use crate::Expr;
use daft_core::{
datatypes::{DataType, Field},
schema::Schema,
series::Series,
};

use common_error::{DaftError, DaftResult};

use super::super::FunctionEvaluator;

pub(super) struct SplitEvaluator {}

impl FunctionEvaluator for SplitEvaluator {
fn fn_name(&self) -> &'static str {
"split"
}

fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult<Field> {
match inputs {
[data, pattern] => match (data.to_field(schema), pattern.to_field(schema)) {
(Ok(data_field), Ok(pattern_field)) => {
match (&data_field.dtype, &pattern_field.dtype) {
(DataType::Utf8, DataType::Utf8) => {
Ok(Field::new(data_field.name, DataType::List(Box::new(DataType::Utf8))))
}
_ => Err(DaftError::TypeError(format!(
"Expects inputs to split to be utf8, but received {data_field} and {pattern_field}",
))),
}
}
(Err(e), _) | (_, Err(e)) => Err(e),
},
_ => Err(DaftError::SchemaMismatch(format!(
"Expected 2 input args, got {}",
inputs.len()
))),
}
}

fn evaluate(&self, inputs: &[Series], _: &Expr) -> DaftResult<Series> {
match inputs {
[data, pattern] => data.utf8_split(pattern),
_ => Err(DaftError::ValueError(format!(
"Expected 2 input args, got {}",
inputs.len()
))),
}
}
}
5 changes: 5 additions & 0 deletions src/daft-dsl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,11 @@ impl PyExpr {
Ok(contains(&self.expr, &pattern.expr).into())
}

pub fn utf8_split(&self, pattern: &Self) -> PyResult<Self> {
use crate::functions::utf8::split;
Ok(split(&self.expr, &pattern.expr).into())
}

pub fn utf8_length(&self) -> PyResult<Self> {
use crate::functions::utf8::length;
Ok(length(&self.expr).into())
Expand Down
1 change: 1 addition & 0 deletions tests/expressions/typing/test_str.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
pytest.param(lambda data, pat: data.str.contains(pat), id="contains"),
pytest.param(lambda data, pat: data.str.startswith(pat), id="startswith"),
pytest.param(lambda data, pat: data.str.endswith(pat), id="endswith"),
pytest.param(lambda data, pat: data.str.endswith(pat), id="split"),
pytest.param(lambda data, pat: data.str.concat(pat), id="concat"),
],
)
Expand Down
Loading

0 comments on commit 069432d

Please sign in to comment.