From f0aab5c9dda3010aa017c7eb9cd692e5abb87f40 Mon Sep 17 00:00:00 2001 From: Nupur Agrawal Date: Tue, 2 Apr 2024 00:10:08 +0530 Subject: [PATCH] [FEAT] Add find functionality for string (#2046) Resolves #1925 Returns the first occurrence of the `substr` provided, `-1`, in case of fails to find. --- daft/daft.pyi | 2 + daft/expressions/expressions.py | 16 +++++ daft/series.py | 6 ++ docs/source/api_docs/expressions.rst | 1 + src/daft-core/src/array/ops/full.rs | 6 ++ src/daft-core/src/array/ops/utf8.rs | 93 +++++++++++++++++++++++-- src/daft-core/src/python/series.rs | 4 ++ src/daft-core/src/series/ops/utf8.rs | 9 +++ src/daft-dsl/src/functions/utf8/find.rs | 50 +++++++++++++ src/daft-dsl/src/functions/utf8/mod.rs | 11 +++ src/daft-dsl/src/python.rs | 5 ++ tests/expressions/typing/test_str.py | 3 +- tests/series/test_utf8_ops.py | 46 ++++++++++++ tests/table/utf8/test_find.py | 26 +++++++ 14 files changed, 273 insertions(+), 5 deletions(-) create mode 100644 src/daft-dsl/src/functions/utf8/find.rs create mode 100644 tests/table/utf8/test_find.py diff --git a/daft/daft.pyi b/daft/daft.pyi index 93dbbfa09e..264cbaeee0 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -945,6 +945,7 @@ class PyExpr: def utf8_capitalize(self) -> PyExpr: ... def utf8_left(self, nchars: PyExpr) -> PyExpr: ... def utf8_right(self, nchars: PyExpr) -> PyExpr: ... + def utf8_find(self, substr: PyExpr) -> PyExpr: ... def image_decode(self, raise_error_on_failure: bool) -> PyExpr: ... def image_encode(self, image_format: ImageFormat) -> PyExpr: ... def image_resize(self, w: int, h: int) -> PyExpr: ... @@ -1037,6 +1038,7 @@ class PySeries: def utf8_capitalize(self) -> PySeries: ... def utf8_left(self, nchars: PySeries) -> PySeries: ... def utf8_right(self, nchars: PySeries) -> PySeries: ... + def utf8_find(self, substr: PySeries) -> PySeries: ... def is_nan(self) -> PySeries: ... def dt_date(self) -> PySeries: ... def dt_day(self) -> PySeries: ... diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 80011ee91c..7684f3aee7 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -1011,6 +1011,22 @@ def right(self, nchars: int | Expression) -> Expression: nchars_expr = Expression._to_expression(nchars) return Expression._from_pyexpr(self._expr.utf8_right(nchars_expr._expr)) + def find(self, substr: str | Expression) -> Expression: + """Returns the index of the first occurrence of the substring in each string + + .. NOTE:: + The returned index is 0-based. + If the substring is not found, -1 is returned. + + Example: + >>> col("x").str.find("foo") + + Returns: + Expression: an Int64 expression with the index of the first occurrence of the substring in each string + """ + substr_expr = Expression._to_expression(substr) + return Expression._from_pyexpr(self._expr.utf8_find(substr_expr._expr)) + class ExpressionListNamespace(ExpressionNamespace): def join(self, delimiter: str | Expression) -> Expression: diff --git a/daft/series.py b/daft/series.py index 48ee0d21ba..17563ba825 100644 --- a/daft/series.py +++ b/daft/series.py @@ -659,6 +659,12 @@ def right(self, nchars: Series) -> Series: assert self._series is not None and nchars._series is not None return Series._from_pyseries(self._series.utf8_right(nchars._series)) + def find(self, substr: Series) -> Series: + if not isinstance(substr, Series): + raise ValueError(f"expected another Series but got {type(substr)}") + assert self._series is not None and substr._series is not None + return Series._from_pyseries(self._series.utf8_find(substr._series)) + class SeriesDateNamespace(SeriesNamespace): def date(self) -> Series: diff --git a/docs/source/api_docs/expressions.rst b/docs/source/api_docs/expressions.rst index ae65057cfd..03711ba383 100644 --- a/docs/source/api_docs/expressions.rst +++ b/docs/source/api_docs/expressions.rst @@ -120,6 +120,7 @@ The following methods are available under the ``expr.str`` attribute. Expression.str.capitalize Expression.str.left Expression.str.right + Expression.str.find .. _api-expressions-temporal: diff --git a/src/daft-core/src/array/ops/full.rs b/src/daft-core/src/array/ops/full.rs index 153a3e0c72..7147a892e0 100644 --- a/src/daft-core/src/array/ops/full.rs +++ b/src/daft-core/src/array/ops/full.rs @@ -25,6 +25,12 @@ where { /// Creates a DataArray of size `length` that is filled with all nulls. fn full_null(name: &str, dtype: &DataType, length: usize) -> Self { + if dtype != &T::get_dtype() && !matches!(T::get_dtype(), DataType::Unknown) { + panic!( + "Cannot create DataArray from dtype: {dtype} with physical type: {}", + T::get_dtype() + ); + } let field = Field::new(name, dtype.clone()); #[cfg(feature = "python")] if dtype.is_python() { diff --git a/src/daft-core/src/array/ops/utf8.rs b/src/daft-core/src/array/ops/utf8.rs index ced1912e84..c216863549 100644 --- a/src/daft-core/src/array/ops/utf8.rs +++ b/src/daft-core/src/array/ops/utf8.rs @@ -1,6 +1,8 @@ use crate::{ array::{DataArray, ListArray}, - datatypes::{BooleanArray, DaftIntegerType, DaftNumericType, Field, UInt64Array, Utf8Array}, + datatypes::{ + BooleanArray, DaftIntegerType, DaftNumericType, Field, Int64Array, UInt64Array, Utf8Array, + }, DataType, Series, }; use arrow2::{self, array::Array}; @@ -186,7 +188,7 @@ impl Utf8Array { return match pattern_scalar_value { None => Ok(BooleanArray::full_null( self.name(), - self.data_type(), + &DataType::Boolean, self.len(), )), Some(pattern_v) => { @@ -479,6 +481,89 @@ impl Utf8Array { Ok(Utf8Array::from((self.name(), Box::new(arrow_result)))) } + pub fn find(&self, substr: &Utf8Array) -> DaftResult { + let self_arrow = self.as_arrow(); + let substr_arrow = substr.as_arrow(); + // Handle empty cases. + if self.is_empty() || substr.is_empty() { + return Ok(Int64Array::empty(self.name(), self.data_type())); + } + match (self.len(), substr.len()) { + // matching len case + (self_len, substr_len) if self_len == substr_len => { + let arrow_result = self_arrow + .iter() + .zip(substr_arrow.iter()) + .map(|(val, substr)| match (val, substr) { + (Some(val), Some(substr)) => { + Some(val.find(substr).map(|pos| pos as i64).unwrap_or(-1)) + } + _ => None, + }) + .collect::(); + + Ok(Int64Array::from((self.name(), Box::new(arrow_result)))) + } + // broadcast pattern case + (self_len, 1) => { + let substr_scalar_value = substr.get(0); + match substr_scalar_value { + None => Ok(Int64Array::full_null( + self.name(), + &DataType::Int64, + self_len, + )), + Some(substr_scalar_value) => { + let arrow_result = self_arrow + .iter() + .map(|val| { + let v = val?; + Some( + v.find(substr_scalar_value) + .map(|pos| pos as i64) + .unwrap_or(-1), + ) + }) + .collect::(); + + Ok(Int64Array::from((self.name(), Box::new(arrow_result)))) + } + } + } + // broadcast self case + (1, substr_len) => { + let self_scalar_value = self.get(0); + match self_scalar_value { + None => Ok(Int64Array::full_null( + self.name(), + &DataType::Int64, + substr_len, + )), + Some(self_scalar_value) => { + let arrow_result = substr_arrow + .iter() + .map(|substr| { + let substr = substr?; + Some( + self_scalar_value + .find(substr) + .map(|pos| pos as i64) + .unwrap_or(-1), + ) + }) + .collect::(); + + Ok(Int64Array::from((self.name(), Box::new(arrow_result)))) + } + } + } + // Mismatched len case: + (self_len, substr_len) => Err(DaftError::ComputeError(format!( + "lhs and rhs have different length arrays: {self_len} vs {substr_len}" + ))), + } + } + pub fn left(&self, n: &DataArray) -> DaftResult where I: DaftIntegerType, @@ -693,7 +778,7 @@ impl Utf8Array { match other_scalar_value { None => Ok(BooleanArray::full_null( self.name(), - self.data_type(), + &DataType::Boolean, self_len, )), Some(other_v) => { @@ -714,7 +799,7 @@ impl Utf8Array { match self_scalar_value { None => Ok(BooleanArray::full_null( self.name(), - self.data_type(), + &DataType::Boolean, other_len, )), Some(self_v) => { diff --git a/src/daft-core/src/python/series.rs b/src/daft-core/src/python/series.rs index b006923a37..4e79e6cdf8 100644 --- a/src/daft-core/src/python/series.rs +++ b/src/daft-core/src/python/series.rs @@ -325,6 +325,10 @@ impl PySeries { Ok(self.series.utf8_right(&nchars.series)?.into()) } + pub fn utf8_find(&self, substr: &Self) -> PyResult { + Ok(self.series.utf8_find(&substr.series)?.into()) + } + pub fn is_nan(&self) -> PyResult { Ok(self.series.is_nan()?.into()) } diff --git a/src/daft-core/src/series/ops/utf8.rs b/src/daft-core/src/series/ops/utf8.rs index 94666ad141..4ae865c992 100644 --- a/src/daft-core/src/series/ops/utf8.rs +++ b/src/daft-core/src/series/ops/utf8.rs @@ -212,4 +212,13 @@ impl Series { ))), } } + + pub fn utf8_find(&self, substr: &Series) -> DaftResult { + match self.data_type() { + DataType::Utf8 => Ok(self.utf8()?.find(substr.utf8()?)?.into_series()), + dt => Err(DaftError::TypeError(format!( + "Find not implemented for type {dt}" + ))), + } + } } diff --git a/src/daft-dsl/src/functions/utf8/find.rs b/src/daft-dsl/src/functions/utf8/find.rs new file mode 100644 index 0000000000..de33487493 --- /dev/null +++ b/src/daft-dsl/src/functions/utf8/find.rs @@ -0,0 +1,50 @@ +use crate::Expr; +use daft_core::{ + datatypes::{DataType, Field}, + schema::Schema, + series::Series, +}; + +use common_error::{DaftError, DaftResult}; + +use super::super::FunctionEvaluator; + +pub(super) struct FindEvaluator {} + +impl FunctionEvaluator for FindEvaluator { + fn fn_name(&self) -> &'static str { + "find" + } + + fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult { + match inputs { + [data, substr] => match (data.to_field(schema), substr.to_field(schema)) { + (Ok(data_field), Ok(substr_field)) => { + match (&data_field.dtype, &substr_field.dtype) { + (DataType::Utf8, DataType::Utf8) => { + Ok(Field::new(data_field.name, DataType::Int64)) + } + _ => Err(DaftError::TypeError(format!( + "Expects inputs to find to be utf8 and utf8, but received {data_field} and {substr_field}", + ))), + } + } + (Err(e), _) | (_, Err(e)) => Err(e), + }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 2 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series], _: &Expr) -> DaftResult { + match inputs { + [data, substr] => data.utf8_find(substr), + _ => Err(DaftError::ValueError(format!( + "Expected 2 input args, got {}", + inputs.len() + ))), + } + } +} diff --git a/src/daft-dsl/src/functions/utf8/mod.rs b/src/daft-dsl/src/functions/utf8/mod.rs index 5cd36b17af..9ad9d8263e 100644 --- a/src/daft-dsl/src/functions/utf8/mod.rs +++ b/src/daft-dsl/src/functions/utf8/mod.rs @@ -3,6 +3,7 @@ mod contains; mod endswith; mod extract; mod extract_all; +mod find; mod left; mod length; mod lower; @@ -20,6 +21,7 @@ use contains::ContainsEvaluator; use endswith::EndswithEvaluator; use extract::ExtractEvaluator; use extract_all::ExtractAllEvaluator; +use find::FindEvaluator; use left::LeftEvaluator; use length::LengthEvaluator; use lower::LowerEvaluator; @@ -54,6 +56,7 @@ pub enum Utf8Expr { Capitalize, Left, Right, + Find, } impl Utf8Expr { @@ -77,6 +80,7 @@ impl Utf8Expr { Capitalize => &CapitalizeEvaluator {}, Left => &LeftEvaluator {}, Right => &RightEvaluator {}, + Find => &FindEvaluator {}, } } } @@ -192,3 +196,10 @@ pub fn right(data: &Expr, count: &Expr) -> Expr { inputs: vec![data.clone(), count.clone()], } } + +pub fn find(data: &Expr, pattern: &Expr) -> Expr { + Expr::Function { + func: super::FunctionExpr::Utf8(Utf8Expr::Find), + inputs: vec![data.clone(), pattern.clone()], + } +} diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index a2898bf20d..1fdf2d0f2d 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -472,6 +472,11 @@ impl PyExpr { Ok(right(&self.expr, &count.expr).into()) } + pub fn utf8_find(&self, substr: &Self) -> PyResult { + use crate::functions::utf8::find; + Ok(find(&self.expr, &substr.expr).into()) + } + pub fn image_decode(&self, raise_error_on_failure: bool) -> PyResult { use crate::functions::image::decode; Ok(decode(&self.expr, raise_error_on_failure).into()) diff --git a/tests/expressions/typing/test_str.py b/tests/expressions/typing/test_str.py index d9b96e799c..aa9a36654e 100644 --- a/tests/expressions/typing/test_str.py +++ b/tests/expressions/typing/test_str.py @@ -20,6 +20,7 @@ pytest.param(lambda data, pat: data.str.concat(pat), id="concat"), pytest.param(lambda data, pat: data.str.extract(pat), id="extract"), pytest.param(lambda data, pat: data.str.extract_all(pat), id="extract_all"), + pytest.param(lambda data, pat: data.str.find(pat), id="find"), ], ) def test_str_compares(binary_data_fixture, op, request): @@ -112,7 +113,7 @@ def test_str_capitalize(): "op", [ pytest.param(lambda data, pat: data.str.left(pat), id="left"), - pytest.param(lambda data, pat: data.str.left(pat), id="right"), + pytest.param(lambda data, pat: data.str.right(pat), id="right"), ], ) def test_str_left_right(binary_data_fixture, op, request): diff --git a/tests/series/test_utf8_ops.py b/tests/series/test_utf8_ops.py index cbad3f2648..a826461a04 100644 --- a/tests/series/test_utf8_ops.py +++ b/tests/series/test_utf8_ops.py @@ -582,3 +582,49 @@ def test_series_utf8_extract_all_bad_pattern() -> None: pattern = Series.from_arrow(pa.array(["["])) with pytest.raises(ValueError): s.str.extract_all(pattern) + + +@pytest.mark.parametrize( + ["data", "substrs", "expected"], + [ + # No broadcast + (["foo", "barbaz", "quux"], ["foo", "baz", "baz"], [0, 3, -1]), + # Broadcast substrs + (["foo", None, "quux"], ["foo"], [0, None, -1]), + # Broadcast data + (["foo"], ["foo", None, "baz"], [0, None, -1]), + # Broadcast null data + ([None], ["foo", "bar", "baz"], [None, None, None]), + # Broadcast null substrs + (["foo", "barbaz", "quux"], [None], [None, None, None]), + # Empty data. + ([[], ["foo", "bar"], []]), + # Empty substrs + ([["foo"] * 4, [], []]), + # Mixed-in nulls + (["foo", None, "barbaz", "quux"], ["oo", "bar", "baz", None], [1, None, 3, None]), + # All null data. + ([None] * 4, ["foo"] * 4, [None] * 4), + # All null substrs + (["foo"] * 4, [None] * 4, [None] * 4), + ], +) +def test_series_utf8_find(data, substrs, expected) -> None: + s = Series.from_arrow(pa.array(data, type=pa.string())) + substrs = Series.from_arrow(pa.array(substrs, type=pa.string())) + result = s.str.find(substrs) + assert result.to_pylist() == expected + + +def test_series_utf8_find_mismatch_len() -> None: + s = Series.from_arrow(pa.array(["foo", "barbaz", "quux"])) + substrs = Series.from_arrow(pa.array(["foo", "baz"], type=pa.string())) + with pytest.raises(ValueError): + s.str.find(substrs) + + +def test_series_utf8_find_bad_dtype() -> None: + s = Series.from_arrow(pa.array([1, 2, 3])) + substrs = Series.from_arrow(pa.array(["foo", "baz", "quux"])) + with pytest.raises(ValueError): + s.str.find(substrs) diff --git a/tests/table/utf8/test_find.py b/tests/table/utf8/test_find.py new file mode 100644 index 0000000000..fccbc382aa --- /dev/null +++ b/tests/table/utf8/test_find.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +import pytest + +from daft.expressions import col, lit +from daft.table import MicroPartition + + +@pytest.mark.parametrize( + ["expr", "data"], + [ + (col("col").str.find("oo"), ["foo", "quux"]), + ( + col("col").str.find(lit("oo")), + ["foo", "quux"], + ), + ( + col("col").str.find(col("emptystrings") + lit("oo")), + ["foo", "quux"], + ), + ], +) +def test_series_utf8_find_broadcast_pattern(expr, data) -> None: + table = MicroPartition.from_pydict({"col": data, "emptystrings": ["", ""]}) + result = table.eval_expression_list([expr]) + assert result.to_pydict() == {"col": [1, -1]}