Skip to content

Commit

Permalink
[FEAT] Add str.lstrip() and str.rstrip() functions (#1944)
Browse files Browse the repository at this point in the history
  • Loading branch information
nsalerni authored Feb 23, 2024
1 parent 5b2fe98 commit bfa7621
Show file tree
Hide file tree
Showing 15 changed files with 286 additions and 0 deletions.
4 changes: 4 additions & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -911,6 +911,8 @@ class PyExpr:
def utf8_length(self) -> PyExpr: ...
def utf8_lower(self) -> PyExpr: ...
def utf8_upper(self) -> PyExpr: ...
def utf8_lstrip(self) -> PyExpr: ...
def utf8_rstrip(self) -> PyExpr: ...
def image_decode(self) -> PyExpr: ...
def image_encode(self, image_format: ImageFormat) -> PyExpr: ...
def image_resize(self, w: int, h: int) -> PyExpr: ...
Expand Down Expand Up @@ -988,6 +990,8 @@ class PySeries:
def utf8_length(self) -> PySeries: ...
def utf8_lower(self) -> PySeries: ...
def utf8_upper(self) -> PySeries: ...
def utf8_lstrip(self) -> PySeries: ...
def utf8_rstrip(self) -> PySeries: ...
def is_nan(self) -> PySeries: ...
def dt_date(self) -> PySeries: ...
def dt_day(self) -> PySeries: ...
Expand Down
22 changes: 22 additions & 0 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,28 @@ def upper(self) -> Expression:
"""
return Expression._from_pyexpr(self._expr.utf8_upper())

def lstrip(self) -> Expression:
"""Strip whitespace from the left side of a UTF-8 string
Example:
>>> col("x").str.lstrip()
Returns:
Expression: a String expression which is `self` with leading whitespace stripped
"""
return Expression._from_pyexpr(self._expr.utf8_lstrip())

def rstrip(self) -> Expression:
"""Strip whitespace from the right side of a UTF-8 string
Example:
>>> col("x").str.rstrip()
Returns:
Expression: a String expression which is `self` with trailing whitespace stripped
"""
return Expression._from_pyexpr(self._expr.utf8_rstrip())


class ExpressionListNamespace(ExpressionNamespace):
def join(self, delimiter: str | Expression) -> Expression:
Expand Down
8 changes: 8 additions & 0 deletions daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,14 @@ def upper(self) -> Series:
assert self._series is not None
return Series._from_pyseries(self._series.utf8_upper())

def lstrip(self) -> Series:
assert self._series is not None
return Series._from_pyseries(self._series.utf8_lstrip())

def rstrip(self) -> Series:
assert self._series is not None
return Series._from_pyseries(self._series.utf8_rstrip())


class SeriesDateNamespace(SeriesNamespace):
def date(self) -> Series:
Expand Down
2 changes: 2 additions & 0 deletions docs/source/api_docs/expressions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ The following methods are available under the ``expr.str`` attribute.
Expression.str.split
Expression.str.lower
Expression.str.upper
Expression.str.lstrip
Expression.str.rstrip

.. _api-expressions-temporal:

Expand Down
26 changes: 26 additions & 0 deletions src/daft-core/src/array/ops/utf8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,32 @@ impl Utf8Array {
Ok(Utf8Array::from((self.name(), Box::new(arrow_result))))
}

pub fn lstrip(&self) -> DaftResult<Utf8Array> {
let self_arrow = self.as_arrow();
let arrow_result = self_arrow
.iter()
.map(|val| {
let v = val?;
Some(v.trim_start())
})
.collect::<arrow2::array::Utf8Array<i64>>()
.with_validity(self_arrow.validity().cloned());
Ok(Utf8Array::from((self.name(), Box::new(arrow_result))))
}

pub fn rstrip(&self) -> DaftResult<Utf8Array> {
let self_arrow = self.as_arrow();
let arrow_result = self_arrow
.iter()
.map(|val| {
let v = val?;
Some(v.trim_end())
})
.collect::<arrow2::array::Utf8Array<i64>>()
.with_validity(self_arrow.validity().cloned());
Ok(Utf8Array::from((self.name(), Box::new(arrow_result))))
}

fn binary_broadcasted_compare<ScalarKernel>(
&self,
other: &Self,
Expand Down
8 changes: 8 additions & 0 deletions src/daft-core/src/python/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,14 @@ impl PySeries {
Ok(self.series.utf8_upper()?.into())
}

pub fn utf8_lstrip(&self) -> PyResult<Self> {
Ok(self.series.utf8_lstrip()?.into())
}

pub fn utf8_rstrip(&self) -> PyResult<Self> {
Ok(self.series.utf8_rstrip()?.into())
}

pub fn is_nan(&self) -> PyResult<Self> {
Ok(self.series.is_nan()?.into())
}
Expand Down
20 changes: 20 additions & 0 deletions src/daft-core/src/series/ops/utf8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,24 @@ impl Series {
))),
}
}

pub fn utf8_lstrip(&self) -> DaftResult<Series> {
match self.data_type() {
DataType::Utf8 => Ok(self.utf8()?.lstrip()?.into_series()),
DataType::Null => Ok(self.clone()),
dt => Err(DaftError::TypeError(format!(
"Lstrip not implemented for type {dt}"
))),
}
}

pub fn utf8_rstrip(&self) -> DaftResult<Series> {
match self.data_type() {
DataType::Utf8 => Ok(self.utf8()?.rstrip()?.into_series()),
DataType::Null => Ok(self.clone()),
dt => Err(DaftError::TypeError(format!(
"Rstrip not implemented for type {dt}"
))),
}
}
}
46 changes: 46 additions & 0 deletions src/daft-dsl/src/functions/utf8/lstrip.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
use daft_core::{
datatypes::{DataType, Field},
schema::Schema,
series::Series,
};

use crate::Expr;
use common_error::{DaftError, DaftResult};

use super::super::FunctionEvaluator;

pub(super) struct LstripEvaluator {}

impl FunctionEvaluator for LstripEvaluator {
fn fn_name(&self) -> &'static str {
"lstrip"
}

fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult<Field> {
match inputs {
[data] => match data.to_field(schema) {
Ok(data_field) => match &data_field.dtype {
DataType::Utf8 => Ok(Field::new(data_field.name, DataType::Utf8)),
_ => Err(DaftError::TypeError(format!(
"Expects input to lstrip to be utf8, but received {data_field}",
))),
},
Err(e) => Err(e),
},
_ => Err(DaftError::SchemaMismatch(format!(
"Expected 1 input args, got {}",
inputs.len()
))),
}
}

fn evaluate(&self, inputs: &[Series], _: &Expr) -> DaftResult<Series> {
match inputs {
[data] => data.utf8_lstrip(),
_ => Err(DaftError::ValueError(format!(
"Expected 1 input args, got {}",
inputs.len()
))),
}
}
}
22 changes: 22 additions & 0 deletions src/daft-dsl/src/functions/utf8/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ mod contains;
mod endswith;
mod length;
mod lower;
mod lstrip;
mod rstrip;
mod split;
mod startswith;
mod upper;
Expand All @@ -10,6 +12,8 @@ use contains::ContainsEvaluator;
use endswith::EndswithEvaluator;
use length::LengthEvaluator;
use lower::LowerEvaluator;
use lstrip::LstripEvaluator;
use rstrip::RstripEvaluator;
use serde::{Deserialize, Serialize};
use split::SplitEvaluator;
use startswith::StartswithEvaluator;
Expand All @@ -28,6 +32,8 @@ pub enum Utf8Expr {
Length,
Lower,
Upper,
Lstrip,
Rstrip,
}

impl Utf8Expr {
Expand All @@ -42,6 +48,8 @@ impl Utf8Expr {
Length => &LengthEvaluator {},
Lower => &LowerEvaluator {},
Upper => &UpperEvaluator {},
Lstrip => &LstripEvaluator {},
Rstrip => &RstripEvaluator {},
}
}
}
Expand Down Expand Up @@ -94,3 +102,17 @@ pub fn upper(data: &Expr) -> Expr {
inputs: vec![data.clone()],
}
}

pub fn lstrip(data: &Expr) -> Expr {
Expr::Function {
func: super::FunctionExpr::Utf8(Utf8Expr::Lstrip),
inputs: vec![data.clone()],
}
}

pub fn rstrip(data: &Expr) -> Expr {
Expr::Function {
func: super::FunctionExpr::Utf8(Utf8Expr::Rstrip),
inputs: vec![data.clone()],
}
}
46 changes: 46 additions & 0 deletions src/daft-dsl/src/functions/utf8/rstrip.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
use daft_core::{
datatypes::{DataType, Field},
schema::Schema,
series::Series,
};

use crate::Expr;
use common_error::{DaftError, DaftResult};

use super::super::FunctionEvaluator;

pub(super) struct RstripEvaluator {}

impl FunctionEvaluator for RstripEvaluator {
fn fn_name(&self) -> &'static str {
"rstrip"
}

fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult<Field> {
match inputs {
[data] => match data.to_field(schema) {
Ok(data_field) => match &data_field.dtype {
DataType::Utf8 => Ok(Field::new(data_field.name, DataType::Utf8)),
_ => Err(DaftError::TypeError(format!(
"Expects input to rstrip to be utf8, but received {data_field}",
))),
},
Err(e) => Err(e),
},
_ => Err(DaftError::SchemaMismatch(format!(
"Expected 1 input args, got {}",
inputs.len()
))),
}
}

fn evaluate(&self, inputs: &[Series], _: &Expr) -> DaftResult<Series> {
match inputs {
[data] => data.utf8_rstrip(),
_ => Err(DaftError::ValueError(format!(
"Expected 1 input args, got {}",
inputs.len()
))),
}
}
}
10 changes: 10 additions & 0 deletions src/daft-dsl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,16 @@ impl PyExpr {
Ok(upper(&self.expr).into())
}

pub fn utf8_lstrip(&self) -> PyResult<Self> {
use crate::functions::utf8::lstrip;
Ok(lstrip(&self.expr).into())
}

pub fn utf8_rstrip(&self) -> PyResult<Self> {
use crate::functions::utf8::rstrip;
Ok(rstrip(&self.expr).into())
}

pub fn image_decode(&self) -> PyResult<Self> {
use crate::functions::image::decode;
Ok(decode(&self.expr).into())
Expand Down
20 changes: 20 additions & 0 deletions tests/expressions/typing/test_str.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,23 @@ def test_str_upper():
run_kernel=s.str.lower,
resolvable=True,
)


def test_str_lstrip():
s = Series.from_arrow(pa.array(["\ta\t", "\nb\n", "\vc\t", " c\t"]), name="arg")
assert_typing_resolve_vs_runtime_behavior(
data=[s],
expr=col(s.name()).str.lstrip(),
run_kernel=s.str.lstrip,
resolvable=True,
)


def test_str_rstrip():
s = Series.from_arrow(pa.array(["\ta\t", "\nb\n", "\vc\t", "\tc "]), name="arg")
assert_typing_resolve_vs_runtime_behavior(
data=[s],
expr=col(s.name()).str.rstrip(),
run_kernel=s.str.rstrip,
resolvable=True,
)
32 changes: 32 additions & 0 deletions tests/series/test_utf8_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,3 +262,35 @@ def test_series_utf8_upper(data, expected) -> None:
s = Series.from_arrow(pa.array(data))
result = s.str.upper()
assert result.to_pylist() == expected


@pytest.mark.parametrize(
["data", "expected"],
[
(["\ta\t", "\nb\n", "\vc\t", "\td ", "e"], ["a\t", "b\n", "c\t", "d ", "e"]),
# With at least one null
(["\ta\t", None, "\vc\t", "\td ", "e"], ["a\t", None, "c\t", "d ", "e"]),
# With all nulls
([None] * 4, [None] * 4),
],
)
def test_series_utf8_lstrip(data, expected) -> None:
s = Series.from_arrow(pa.array(data))
result = s.str.lstrip()
assert result.to_pylist() == expected


@pytest.mark.parametrize(
["data", "expected"],
[
(["\ta\t", "\nb\n", "\vc\t", "\td ", "e"], ["\ta", "\nb", "\vc", "\td", "e"]),
# With at least one null
(["\ta\t", None, "\vc\t", "\td ", "e"], ["\ta", None, "\vc", "\td", "e"]),
# With all nulls
([None] * 4, [None] * 4),
],
)
def test_series_utf8_rstrip(data, expected) -> None:
s = Series.from_arrow(pa.array(data))
result = s.str.rstrip()
assert result.to_pylist() == expected
10 changes: 10 additions & 0 deletions tests/table/utf8/test_lstrip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from __future__ import annotations

from daft.expressions import col
from daft.table import MicroPartition


def test_utf8_lstrip():
table = MicroPartition.from_pydict({"col": ["\ta\t", None, "\nb\n", "\vc\t", "\td ", "e"]})
result = table.eval_expression_list([col("col").str.lstrip()])
assert result.to_pydict() == {"col": ["a\t", None, "b\n", "c\t", "d ", "e"]}
10 changes: 10 additions & 0 deletions tests/table/utf8/test_rstrip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from __future__ import annotations

from daft.expressions import col
from daft.table import MicroPartition


def test_utf8_rstrip():
table = MicroPartition.from_pydict({"col": ["\ta\t", None, "\nb\n", "\vc\t", "\td ", "e"]})
result = table.eval_expression_list([col("col").str.rstrip()])
assert result.to_pydict() == {"col": ["\ta", None, "\nb", "\vc", "\td", "e"]}

0 comments on commit bfa7621

Please sign in to comment.