Skip to content

Commit

Permalink
[FEAT] Add find functionality for string (#2046)
Browse files Browse the repository at this point in the history
Resolves #1925 

Returns the index of the first occurrence of the provided `substr`, or `-1`
if the substring is not found.
  • Loading branch information
murex971 authored Apr 1, 2024
1 parent 8a22679 commit f0aab5c
Show file tree
Hide file tree
Showing 14 changed files with 273 additions and 5 deletions.
2 changes: 2 additions & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -945,6 +945,7 @@ class PyExpr:
def utf8_capitalize(self) -> PyExpr: ...
def utf8_left(self, nchars: PyExpr) -> PyExpr: ...
def utf8_right(self, nchars: PyExpr) -> PyExpr: ...
# Builds the string `find` expression: `self` is the data to search, `substr` the substring to look for.
def utf8_find(self, substr: PyExpr) -> PyExpr: ...
def image_decode(self, raise_error_on_failure: bool) -> PyExpr: ...
def image_encode(self, image_format: ImageFormat) -> PyExpr: ...
def image_resize(self, w: int, h: int) -> PyExpr: ...
Expand Down Expand Up @@ -1037,6 +1038,7 @@ class PySeries:
def utf8_capitalize(self) -> PySeries: ...
def utf8_left(self, nchars: PySeries) -> PySeries: ...
def utf8_right(self, nchars: PySeries) -> PySeries: ...
# Index of the first occurrence of `substr` in each string of this series; -1 when the substring is not found.
def utf8_find(self, substr: PySeries) -> PySeries: ...
def is_nan(self) -> PySeries: ...
def dt_date(self) -> PySeries: ...
def dt_day(self) -> PySeries: ...
Expand Down
16 changes: 16 additions & 0 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1011,6 +1011,22 @@ def right(self, nchars: int | Expression) -> Expression:
nchars_expr = Expression._to_expression(nchars)
return Expression._from_pyexpr(self._expr.utf8_right(nchars_expr._expr))

def find(self, substr: str | Expression) -> Expression:
    """Locates the first occurrence of ``substr`` within each string

    .. NOTE::
        Indices are 0-based.
        Strings that do not contain the substring produce -1.

    Example:
        >>> col("x").str.find("foo")

    Args:
        substr: the substring to search for, given as a string or as an Expression

    Returns:
        Expression: an Int64 expression holding the index of the first match in each string
    """
    # NOTE(review): the Rust kernel relies on `str::find`, so the index is presumably a
    # byte offset for non-ASCII text -- confirm before documenting it for users.
    # Normalize `substr` through `_to_expression` so its `_expr` can be handed to the native call.
    needle = Expression._to_expression(substr)
    native_expr = self._expr.utf8_find(needle._expr)
    return Expression._from_pyexpr(native_expr)


class ExpressionListNamespace(ExpressionNamespace):
def join(self, delimiter: str | Expression) -> Expression:
Expand Down
6 changes: 6 additions & 0 deletions daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,12 @@ def right(self, nchars: Series) -> Series:
assert self._series is not None and nchars._series is not None
return Series._from_pyseries(self._series.utf8_right(nchars._series))

def find(self, substr: Series) -> Series:
    """Finds the first occurrence of each ``substr`` value in this series' strings.

    Args:
        substr: Series holding the substring(s) to search for

    Returns:
        Series: the result of the native ``utf8_find`` call wrapped back into a Series

    Raises:
        ValueError: if ``substr`` is not a Series
    """
    if isinstance(substr, Series):
        assert self._series is not None and substr._series is not None
        found = self._series.utf8_find(substr._series)
        return Series._from_pyseries(found)
    raise ValueError(f"expected another Series but got {type(substr)}")


class SeriesDateNamespace(SeriesNamespace):
def date(self) -> Series:
Expand Down
1 change: 1 addition & 0 deletions docs/source/api_docs/expressions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ The following methods are available under the ``expr.str`` attribute.
Expression.str.capitalize
Expression.str.left
Expression.str.right
Expression.str.find

.. _api-expressions-temporal:

Expand Down
6 changes: 6 additions & 0 deletions src/daft-core/src/array/ops/full.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ where
{
/// Creates a DataArray<T> of size `length` that is filled with all nulls.
fn full_null(name: &str, dtype: &DataType, length: usize) -> Self {
if dtype != &T::get_dtype() && !matches!(T::get_dtype(), DataType::Unknown) {
panic!(
"Cannot create DataArray from dtype: {dtype} with physical type: {}",
T::get_dtype()
);
}
let field = Field::new(name, dtype.clone());
#[cfg(feature = "python")]
if dtype.is_python() {
Expand Down
93 changes: 89 additions & 4 deletions src/daft-core/src/array/ops/utf8.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use crate::{
array::{DataArray, ListArray},
datatypes::{BooleanArray, DaftIntegerType, DaftNumericType, Field, UInt64Array, Utf8Array},
datatypes::{
BooleanArray, DaftIntegerType, DaftNumericType, Field, Int64Array, UInt64Array, Utf8Array,
},
DataType, Series,
};
use arrow2::{self, array::Array};
Expand Down Expand Up @@ -186,7 +188,7 @@ impl Utf8Array {
return match pattern_scalar_value {
None => Ok(BooleanArray::full_null(
self.name(),
self.data_type(),
&DataType::Boolean,
self.len(),
)),
Some(pattern_v) => {
Expand Down Expand Up @@ -479,6 +481,89 @@ impl Utf8Array {
Ok(Utf8Array::from((self.name(), Box::new(arrow_result))))
}

/// Finds the first occurrence of `substr` within each string of this array.
///
/// Every non-null result is the 0-based offset reported by [`str::find`] (a byte
/// offset into the string), or `-1` when the substring does not occur. A null on
/// either side produces a null result. The returned array carries `self`'s name.
///
/// Length handling:
/// * if either array is empty, an empty `Int64` array is returned;
/// * arrays of equal length are processed element-wise;
/// * a length-1 `substr` is broadcast against every value of `self`;
/// * a length-1 `self` is broadcast against every value of `substr`.
///
/// # Errors
///
/// Returns `DaftError::ComputeError` when the lengths differ and neither side has
/// length 1.
pub fn find(&self, substr: &Utf8Array) -> DaftResult<Int64Array> {
    // Offset of the first match of `needle` in `haystack`, or -1 when there is none.
    fn first_match(haystack: &str, needle: &str) -> i64 {
        // A `str` never exceeds `isize::MAX` bytes, so the cast to `i64` cannot truncate.
        haystack.find(needle).map_or(-1, |pos| pos as i64)
    }

    let self_arrow = self.as_arrow();
    let substr_arrow = substr.as_arrow();
    // Handle empty cases.
    if self.is_empty() || substr.is_empty() {
        // The output is an Int64 array, so it must be created with the Int64 dtype
        // rather than this array's own (Utf8) dtype.
        return Ok(Int64Array::empty(self.name(), &DataType::Int64));
    }

    let arrow_result = match (self.len(), substr.len()) {
        // matching len case
        (self_len, substr_len) if self_len == substr_len => self_arrow
            .iter()
            .zip(substr_arrow.iter())
            .map(|(val, needle)| Some(first_match(val?, needle?)))
            .collect::<arrow2::array::Int64Array>(),
        // broadcast pattern case
        (self_len, 1) => match substr.get(0) {
            // A null pattern makes every result null.
            None => {
                return Ok(Int64Array::full_null(
                    self.name(),
                    &DataType::Int64,
                    self_len,
                ));
            }
            Some(needle) => self_arrow
                .iter()
                .map(|val| Some(first_match(val?, needle)))
                .collect::<arrow2::array::Int64Array>(),
        },
        // broadcast self case
        (1, substr_len) => match self.get(0) {
            // A null haystack makes every result null.
            None => {
                return Ok(Int64Array::full_null(
                    self.name(),
                    &DataType::Int64,
                    substr_len,
                ));
            }
            Some(haystack) => substr_arrow
                .iter()
                .map(|needle| Some(first_match(haystack, needle?)))
                .collect::<arrow2::array::Int64Array>(),
        },
        // Mismatched len case:
        (self_len, substr_len) => {
            return Err(DaftError::ComputeError(format!(
                "lhs and rhs have different length arrays: {self_len} vs {substr_len}"
            )));
        }
    };

    Ok(Int64Array::from((self.name(), Box::new(arrow_result))))
}

pub fn left<I>(&self, n: &DataArray<I>) -> DaftResult<Utf8Array>
where
I: DaftIntegerType,
Expand Down Expand Up @@ -693,7 +778,7 @@ impl Utf8Array {
match other_scalar_value {
None => Ok(BooleanArray::full_null(
self.name(),
self.data_type(),
&DataType::Boolean,
self_len,
)),
Some(other_v) => {
Expand All @@ -714,7 +799,7 @@ impl Utf8Array {
match self_scalar_value {
None => Ok(BooleanArray::full_null(
self.name(),
self.data_type(),
&DataType::Boolean,
other_len,
)),
Some(self_v) => {
Expand Down
4 changes: 4 additions & 0 deletions src/daft-core/src/python/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,10 @@ impl PySeries {
Ok(self.series.utf8_right(&nchars.series)?.into())
}

/// Runs `utf8_find` on the wrapped series with `substr`'s series and wraps the result;
/// any error is propagated to the caller.
pub fn utf8_find(&self, substr: &Self) -> PyResult<Self> {
    let found = self.series.utf8_find(&substr.series)?;
    Ok(found.into())
}

pub fn is_nan(&self) -> PyResult<Self> {
Ok(self.series.is_nan()?.into())
}
Expand Down
9 changes: 9 additions & 0 deletions src/daft-core/src/series/ops/utf8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,4 +212,13 @@ impl Series {
))),
}
}

/// Finds `substr` in this series by delegating to the `find` kernel of the
/// underlying UTF-8 arrays.
///
/// # Errors
///
/// Returns `DaftError::TypeError` when this series is not of type `Utf8`; errors from
/// downcasting either series or from the kernel itself are propagated.
pub fn utf8_find(&self, substr: &Series) -> DaftResult<Series> {
    let dt = self.data_type();
    if !matches!(dt, DataType::Utf8) {
        return Err(DaftError::TypeError(format!(
            "Find not implemented for type {dt}"
        )));
    }

    let haystack = self.utf8()?;
    let needle = substr.utf8()?;
    Ok(haystack.find(needle)?.into_series())
}
}
50 changes: 50 additions & 0 deletions src/daft-dsl/src/functions/utf8/find.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
use crate::Expr;
use daft_core::{
datatypes::{DataType, Field},
schema::Schema,
series::Series,
};

use common_error::{DaftError, DaftResult};

use super::super::FunctionEvaluator;

/// Evaluator for the UTF-8 `find` function, which takes two inputs: the data to
/// search and the substring to look for.
pub(super) struct FindEvaluator {}

impl FunctionEvaluator for FindEvaluator {
    /// Name under which this function is reported.
    fn fn_name(&self) -> &'static str {
        "find"
    }

    /// Resolves the output field: an `Int64` field named after the data input.
    ///
    /// Fails with `SchemaMismatch` unless exactly two inputs are given, and with
    /// `TypeError` unless both inputs resolve to `Utf8`.
    fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult<Field> {
        let (data, substr) = match inputs {
            [data, substr] => (data, substr),
            _ => {
                return Err(DaftError::SchemaMismatch(format!(
                    "Expected 2 input args, got {}",
                    inputs.len()
                )))
            }
        };

        // Both inputs are resolved before either result is inspected; an error from
        // the data input is reported ahead of one from the substring input.
        let data_field = data.to_field(schema);
        let substr_field = substr.to_field(schema);
        let data_field = data_field?;
        let substr_field = substr_field?;

        let both_utf8 = matches!(
            (&data_field.dtype, &substr_field.dtype),
            (DataType::Utf8, DataType::Utf8)
        );
        if both_utf8 {
            Ok(Field::new(data_field.name, DataType::Int64))
        } else {
            Err(DaftError::TypeError(format!(
                "Expects inputs to find to be utf8 and utf8, but received {data_field} and {substr_field}",
            )))
        }
    }

    /// Evaluates the function on exactly two input series via `Series::utf8_find`.
    fn evaluate(&self, inputs: &[Series], _: &Expr) -> DaftResult<Series> {
        if let [data, substr] = inputs {
            return data.utf8_find(substr);
        }
        Err(DaftError::ValueError(format!(
            "Expected 2 input args, got {}",
            inputs.len()
        )))
    }
}
11 changes: 11 additions & 0 deletions src/daft-dsl/src/functions/utf8/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ mod contains;
mod endswith;
mod extract;
mod extract_all;
mod find;
mod left;
mod length;
mod lower;
Expand All @@ -20,6 +21,7 @@ use contains::ContainsEvaluator;
use endswith::EndswithEvaluator;
use extract::ExtractEvaluator;
use extract_all::ExtractAllEvaluator;
use find::FindEvaluator;
use left::LeftEvaluator;
use length::LengthEvaluator;
use lower::LowerEvaluator;
Expand Down Expand Up @@ -54,6 +56,7 @@ pub enum Utf8Expr {
Capitalize,
Left,
Right,
Find,
}

impl Utf8Expr {
Expand All @@ -77,6 +80,7 @@ impl Utf8Expr {
Capitalize => &CapitalizeEvaluator {},
Left => &LeftEvaluator {},
Right => &RightEvaluator {},
Find => &FindEvaluator {},
}
}
}
Expand Down Expand Up @@ -192,3 +196,10 @@ pub fn right(data: &Expr, count: &Expr) -> Expr {
inputs: vec![data.clone(), count.clone()],
}
}

/// Builds the `Utf8Expr::Find` function expression with `data` and `pattern`
/// (both cloned) as its two inputs, in that order.
pub fn find(data: &Expr, pattern: &Expr) -> Expr {
    let func = super::FunctionExpr::Utf8(Utf8Expr::Find);
    let inputs = vec![data.clone(), pattern.clone()];
    Expr::Function { func, inputs }
}
5 changes: 5 additions & 0 deletions src/daft-dsl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,11 @@ impl PyExpr {
Ok(right(&self.expr, &count.expr).into())
}

/// Wraps the `find` function expression built from this expression and `substr`.
pub fn utf8_find(&self, substr: &Self) -> PyResult<Self> {
    let expr = crate::functions::utf8::find(&self.expr, &substr.expr);
    Ok(expr.into())
}

pub fn image_decode(&self, raise_error_on_failure: bool) -> PyResult<Self> {
use crate::functions::image::decode;
Ok(decode(&self.expr, raise_error_on_failure).into())
Expand Down
3 changes: 2 additions & 1 deletion tests/expressions/typing/test_str.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
pytest.param(lambda data, pat: data.str.concat(pat), id="concat"),
pytest.param(lambda data, pat: data.str.extract(pat), id="extract"),
pytest.param(lambda data, pat: data.str.extract_all(pat), id="extract_all"),
pytest.param(lambda data, pat: data.str.find(pat), id="find"),
],
)
def test_str_compares(binary_data_fixture, op, request):
Expand Down Expand Up @@ -112,7 +113,7 @@ def test_str_capitalize():
"op",
[
pytest.param(lambda data, pat: data.str.left(pat), id="left"),
pytest.param(lambda data, pat: data.str.left(pat), id="right"),
pytest.param(lambda data, pat: data.str.right(pat), id="right"),
],
)
def test_str_left_right(binary_data_fixture, op, request):
Expand Down
46 changes: 46 additions & 0 deletions tests/series/test_utf8_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,3 +582,49 @@ def test_series_utf8_extract_all_bad_pattern() -> None:
pattern = Series.from_arrow(pa.array(["["]))
with pytest.raises(ValueError):
s.str.extract_all(pattern)


@pytest.mark.parametrize(
    ["data", "substrs", "expected"],
    [
        # No broadcast
        (["foo", "barbaz", "quux"], ["foo", "baz", "baz"], [0, 3, -1]),
        # Broadcast substrs
        (["foo", None, "quux"], ["foo"], [0, None, -1]),
        # Broadcast data
        (["foo"], ["foo", None, "baz"], [0, None, -1]),
        # Broadcast null data
        ([None], ["foo", "bar", "baz"], [None, None, None]),
        # Broadcast null substrs
        (["foo", "barbaz", "quux"], [None], [None, None, None]),
        # Empty data.
        ([], ["foo", "bar"], []),
        # Empty substrs
        (["foo"] * 4, [], []),
        # Mixed-in nulls
        (["foo", None, "barbaz", "quux"], ["oo", "bar", "baz", None], [1, None, 3, None]),
        # All null data.
        ([None] * 4, ["foo"] * 4, [None] * 4),
        # All null substrs
        (["foo"] * 4, [None] * 4, [None] * 4),
    ],
)
def test_series_utf8_find(data, substrs, expected) -> None:
    """Checks `str.find` results across matching-length, broadcast, empty and null inputs."""
    haystack = Series.from_arrow(pa.array(data, type=pa.string()))
    needles = Series.from_arrow(pa.array(substrs, type=pa.string()))
    found = haystack.str.find(needles)
    assert found.to_pylist() == expected


def test_series_utf8_find_mismatch_len() -> None:
    """`str.find` must raise ValueError for a 3-element series and a 2-element substring series."""
    haystack = Series.from_arrow(pa.array(["foo", "barbaz", "quux"]))
    needles = Series.from_arrow(pa.array(["foo", "baz"], type=pa.string()))
    with pytest.raises(ValueError):
        haystack.str.find(needles)


def test_series_utf8_find_bad_dtype() -> None:
    """`str.find` must raise ValueError when the searched series holds integers instead of strings."""
    numbers = Series.from_arrow(pa.array([1, 2, 3]))
    needles = Series.from_arrow(pa.array(["foo", "baz", "quux"]))
    with pytest.raises(ValueError):
        numbers.str.find(needles)
Loading

0 comments on commit f0aab5c

Please sign in to comment.