Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] Add str.lstrip() and str.rstrip() functions #1944

Merged
merged 1 commit into from
Feb 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -911,6 +911,8 @@ class PyExpr:
def utf8_length(self) -> PyExpr: ...
def utf8_lower(self) -> PyExpr: ...
def utf8_upper(self) -> PyExpr: ...
# Builds an expression that strips leading whitespace from a UTF-8 string expression.
def utf8_lstrip(self) -> PyExpr: ...
# Builds an expression that strips trailing whitespace from a UTF-8 string expression.
def utf8_rstrip(self) -> PyExpr: ...
def image_decode(self) -> PyExpr: ...
def image_encode(self, image_format: ImageFormat) -> PyExpr: ...
def image_resize(self, w: int, h: int) -> PyExpr: ...
Expand Down Expand Up @@ -988,6 +990,8 @@ class PySeries:
def utf8_length(self) -> PySeries: ...
def utf8_lower(self) -> PySeries: ...
def utf8_upper(self) -> PySeries: ...
# Returns a new series with leading whitespace stripped from each UTF-8 value.
def utf8_lstrip(self) -> PySeries: ...
# Returns a new series with trailing whitespace stripped from each UTF-8 value.
def utf8_rstrip(self) -> PySeries: ...
def is_nan(self) -> PySeries: ...
def dt_date(self) -> PySeries: ...
def dt_day(self) -> PySeries: ...
Expand Down
22 changes: 22 additions & 0 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,28 @@ def upper(self) -> Expression:
"""
return Expression._from_pyexpr(self._expr.utf8_upper())

def lstrip(self) -> Expression:
    """Remove leading whitespace from each value of a UTF-8 string expression

    Example:
        >>> col("x").str.lstrip()

    Returns:
        Expression: a String expression equal to `self` with its leading whitespace removed
    """
    stripped = self._expr.utf8_lstrip()
    return Expression._from_pyexpr(stripped)

def rstrip(self) -> Expression:
    """Remove trailing whitespace from each value of a UTF-8 string expression

    Example:
        >>> col("x").str.rstrip()

    Returns:
        Expression: a String expression equal to `self` with its trailing whitespace removed
    """
    stripped = self._expr.utf8_rstrip()
    return Expression._from_pyexpr(stripped)


class ExpressionListNamespace(ExpressionNamespace):
def join(self, delimiter: str | Expression) -> Expression:
Expand Down
8 changes: 8 additions & 0 deletions daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,14 @@ def upper(self) -> Series:
assert self._series is not None
return Series._from_pyseries(self._series.utf8_upper())

def lstrip(self) -> Series:
    """Return a new Series produced by the underlying series' ``utf8_lstrip`` kernel."""
    assert self._series is not None
    stripped = self._series.utf8_lstrip()
    return Series._from_pyseries(stripped)

def rstrip(self) -> Series:
    """Return a new Series produced by the underlying series' ``utf8_rstrip`` kernel."""
    assert self._series is not None
    stripped = self._series.utf8_rstrip()
    return Series._from_pyseries(stripped)


class SeriesDateNamespace(SeriesNamespace):
def date(self) -> Series:
Expand Down
2 changes: 2 additions & 0 deletions docs/source/api_docs/expressions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ The following methods are available under the ``expr.str`` attribute.
Expression.str.split
Expression.str.lower
Expression.str.upper
Expression.str.lstrip
Expression.str.rstrip

.. _api-expressions-temporal:

Expand Down
26 changes: 26 additions & 0 deletions src/daft-core/src/array/ops/utf8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,32 @@ impl Utf8Array {
Ok(Utf8Array::from((self.name(), Box::new(arrow_result))))
}

/// Returns a new array, under the same name, in which every non-null string
/// has had its leading whitespace removed via [`str::trim_start`].
///
/// Null entries map to null, and the source validity bitmap is re-applied to
/// the result.
pub fn lstrip(&self) -> DaftResult<Utf8Array> {
    let source = self.as_arrow();
    let stripped: arrow2::array::Utf8Array<i64> = source
        .iter()
        .map(|maybe_str| maybe_str.map(str::trim_start))
        .collect();
    let stripped = stripped.with_validity(source.validity().cloned());
    Ok(Utf8Array::from((self.name(), Box::new(stripped))))
}

/// Returns a new array, under the same name, in which every non-null string
/// has had its trailing whitespace removed via [`str::trim_end`].
///
/// Null entries map to null, and the source validity bitmap is re-applied to
/// the result.
pub fn rstrip(&self) -> DaftResult<Utf8Array> {
    let source = self.as_arrow();
    let stripped: arrow2::array::Utf8Array<i64> = source
        .iter()
        .map(|maybe_str| maybe_str.map(str::trim_end))
        .collect();
    let stripped = stripped.with_validity(source.validity().cloned());
    Ok(Utf8Array::from((self.name(), Box::new(stripped))))
}

fn binary_broadcasted_compare<ScalarKernel>(
&self,
other: &Self,
Expand Down
8 changes: 8 additions & 0 deletions src/daft-core/src/python/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,14 @@ impl PySeries {
Ok(self.series.utf8_upper()?.into())
}

/// Python-facing wrapper: runs `utf8_lstrip` on the wrapped series and
/// converts the result back into a `PySeries`.
pub fn utf8_lstrip(&self) -> PyResult<Self> {
    let stripped = self.series.utf8_lstrip()?;
    Ok(stripped.into())
}

/// Python-facing wrapper: runs `utf8_rstrip` on the wrapped series and
/// converts the result back into a `PySeries`.
pub fn utf8_rstrip(&self) -> PyResult<Self> {
    let stripped = self.series.utf8_rstrip()?;
    Ok(stripped.into())
}

pub fn is_nan(&self) -> PyResult<Self> {
Ok(self.series.is_nan()?.into())
}
Expand Down
20 changes: 20 additions & 0 deletions src/daft-core/src/series/ops/utf8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,24 @@ impl Series {
))),
}
}

/// Dispatches the left-strip kernel on the series' data type.
///
/// `Utf8` series are stripped, `Null` series are returned as a clone, and
/// any other data type yields a `DaftError::TypeError`.
pub fn utf8_lstrip(&self) -> DaftResult<Series> {
    let dtype = self.data_type();
    match dtype {
        DataType::Null => Ok(self.clone()),
        DataType::Utf8 => {
            let stripped = self.utf8()?.lstrip()?;
            Ok(stripped.into_series())
        }
        _ => Err(DaftError::TypeError(format!(
            "Lstrip not implemented for type {dtype}"
        ))),
    }
}

/// Dispatches the right-strip kernel on the series' data type.
///
/// `Utf8` series are stripped, `Null` series are returned as a clone, and
/// any other data type yields a `DaftError::TypeError`.
pub fn utf8_rstrip(&self) -> DaftResult<Series> {
    let dtype = self.data_type();
    match dtype {
        DataType::Null => Ok(self.clone()),
        DataType::Utf8 => {
            let stripped = self.utf8()?.rstrip()?;
            Ok(stripped.into_series())
        }
        _ => Err(DaftError::TypeError(format!(
            "Rstrip not implemented for type {dtype}"
        ))),
    }
}
}
46 changes: 46 additions & 0 deletions src/daft-dsl/src/functions/utf8/lstrip.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
use daft_core::{
datatypes::{DataType, Field},
schema::Schema,
series::Series,
};

use crate::Expr;
use common_error::{DaftError, DaftResult};

use super::super::FunctionEvaluator;

/// Evaluator for the `lstrip` UTF-8 function.
pub(super) struct LstripEvaluator {}

impl FunctionEvaluator for LstripEvaluator {
    fn fn_name(&self) -> &'static str {
        "lstrip"
    }

    /// Resolves the output field: exactly one input is required and it must
    /// resolve to a `Utf8` field, in which case the output keeps the input's
    /// name with dtype `Utf8`.
    fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult<Field> {
        let data = match inputs {
            [data] => data,
            _ => {
                return Err(DaftError::SchemaMismatch(format!(
                    "Expected 1 input args, got {}",
                    inputs.len()
                )))
            }
        };

        // Errors from resolving the input field are propagated unchanged.
        let data_field = data.to_field(schema)?;
        if matches!(data_field.dtype, DataType::Utf8) {
            Ok(Field::new(data_field.name, DataType::Utf8))
        } else {
            Err(DaftError::TypeError(format!(
                "Expects input to lstrip to be utf8, but received {data_field}",
            )))
        }
    }

    /// Applies `utf8_lstrip` to the single input series; any other number of
    /// inputs is a `DaftError::ValueError`.
    fn evaluate(&self, inputs: &[Series], _: &Expr) -> DaftResult<Series> {
        if let [data] = inputs {
            data.utf8_lstrip()
        } else {
            Err(DaftError::ValueError(format!(
                "Expected 1 input args, got {}",
                inputs.len()
            )))
        }
    }
}
22 changes: 22 additions & 0 deletions src/daft-dsl/src/functions/utf8/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ mod contains;
mod endswith;
mod length;
mod lower;
mod lstrip;
mod rstrip;
mod split;
mod startswith;
mod upper;
Expand All @@ -10,6 +12,8 @@ use contains::ContainsEvaluator;
use endswith::EndswithEvaluator;
use length::LengthEvaluator;
use lower::LowerEvaluator;
use lstrip::LstripEvaluator;
use rstrip::RstripEvaluator;
use serde::{Deserialize, Serialize};
use split::SplitEvaluator;
use startswith::StartswithEvaluator;
Expand All @@ -28,6 +32,8 @@ pub enum Utf8Expr {
Length,
Lower,
Upper,
Lstrip,
Rstrip,
}

impl Utf8Expr {
Expand All @@ -42,6 +48,8 @@ impl Utf8Expr {
Length => &LengthEvaluator {},
Lower => &LowerEvaluator {},
Upper => &UpperEvaluator {},
Lstrip => &LstripEvaluator {},
Rstrip => &RstripEvaluator {},
}
}
}
Expand Down Expand Up @@ -94,3 +102,17 @@ pub fn upper(data: &Expr) -> Expr {
inputs: vec![data.clone()],
}
}

/// Builds a function expression applying `Utf8Expr::Lstrip` to a clone of `data`.
pub fn lstrip(data: &Expr) -> Expr {
    let func = super::FunctionExpr::Utf8(Utf8Expr::Lstrip);
    let inputs = vec![data.clone()];
    Expr::Function { func, inputs }
}

/// Builds a function expression applying `Utf8Expr::Rstrip` to a clone of `data`.
pub fn rstrip(data: &Expr) -> Expr {
    let func = super::FunctionExpr::Utf8(Utf8Expr::Rstrip);
    let inputs = vec![data.clone()];
    Expr::Function { func, inputs }
}
46 changes: 46 additions & 0 deletions src/daft-dsl/src/functions/utf8/rstrip.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
use daft_core::{
datatypes::{DataType, Field},
schema::Schema,
series::Series,
};

use crate::Expr;
use common_error::{DaftError, DaftResult};

use super::super::FunctionEvaluator;

/// Evaluator for the `rstrip` UTF-8 function.
pub(super) struct RstripEvaluator {}

impl FunctionEvaluator for RstripEvaluator {
    fn fn_name(&self) -> &'static str {
        "rstrip"
    }

    /// Resolves the output field: exactly one input is required and it must
    /// resolve to a `Utf8` field, in which case the output keeps the input's
    /// name with dtype `Utf8`.
    fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult<Field> {
        let data = match inputs {
            [data] => data,
            _ => {
                return Err(DaftError::SchemaMismatch(format!(
                    "Expected 1 input args, got {}",
                    inputs.len()
                )))
            }
        };

        // Errors from resolving the input field are propagated unchanged.
        let data_field = data.to_field(schema)?;
        if matches!(data_field.dtype, DataType::Utf8) {
            Ok(Field::new(data_field.name, DataType::Utf8))
        } else {
            Err(DaftError::TypeError(format!(
                "Expects input to rstrip to be utf8, but received {data_field}",
            )))
        }
    }

    /// Applies `utf8_rstrip` to the single input series; any other number of
    /// inputs is a `DaftError::ValueError`.
    fn evaluate(&self, inputs: &[Series], _: &Expr) -> DaftResult<Series> {
        if let [data] = inputs {
            data.utf8_rstrip()
        } else {
            Err(DaftError::ValueError(format!(
                "Expected 1 input args, got {}",
                inputs.len()
            )))
        }
    }
}
10 changes: 10 additions & 0 deletions src/daft-dsl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,16 @@ impl PyExpr {
Ok(upper(&self.expr).into())
}

/// Python-facing wrapper: wraps this expression in the `lstrip` UTF-8 function
/// expression and converts it back into a `PyExpr`.
pub fn utf8_lstrip(&self) -> PyResult<Self> {
    let stripped = crate::functions::utf8::lstrip(&self.expr);
    Ok(stripped.into())
}

/// Python-facing wrapper: wraps this expression in the `rstrip` UTF-8 function
/// expression and converts it back into a `PyExpr`.
pub fn utf8_rstrip(&self) -> PyResult<Self> {
    let stripped = crate::functions::utf8::rstrip(&self.expr);
    Ok(stripped.into())
}

pub fn image_decode(&self) -> PyResult<Self> {
use crate::functions::image::decode;
Ok(decode(&self.expr).into())
Expand Down
20 changes: 20 additions & 0 deletions tests/expressions/typing/test_str.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,23 @@ def test_str_upper():
run_kernel=s.str.lower,
resolvable=True,
)


def test_str_lstrip():
    """Check that type resolution of ``str.lstrip`` agrees with running the kernel on a string Series."""
    values = pa.array(["\ta\t", "\nb\n", "\vc\t", " c\t"])
    s = Series.from_arrow(values, name="arg")
    strip_expr = col(s.name()).str.lstrip()
    assert_typing_resolve_vs_runtime_behavior(
        data=[s],
        expr=strip_expr,
        run_kernel=s.str.lstrip,
        resolvable=True,
    )


def test_str_rstrip():
    """Check that type resolution of ``str.rstrip`` agrees with running the kernel on a string Series."""
    values = pa.array(["\ta\t", "\nb\n", "\vc\t", "\tc "])
    s = Series.from_arrow(values, name="arg")
    strip_expr = col(s.name()).str.rstrip()
    assert_typing_resolve_vs_runtime_behavior(
        data=[s],
        expr=strip_expr,
        run_kernel=s.str.rstrip,
        resolvable=True,
    )
32 changes: 32 additions & 0 deletions tests/series/test_utf8_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,3 +262,35 @@ def test_series_utf8_upper(data, expected) -> None:
s = Series.from_arrow(pa.array(data))
result = s.str.upper()
assert result.to_pylist() == expected


@pytest.mark.parametrize(
    ["data", "expected"],
    [
        # No nulls: only the leading whitespace is removed
        (["\ta\t", "\nb\n", "\vc\t", "\td ", "e"], ["a\t", "b\n", "c\t", "d ", "e"]),
        # With at least one null
        (["\ta\t", None, "\vc\t", "\td ", "e"], ["a\t", None, "c\t", "d ", "e"]),
        # With all nulls
        ([None, None, None, None], [None, None, None, None]),
    ],
)
def test_series_utf8_lstrip(data, expected) -> None:
    """``Series.str.lstrip`` strips leading whitespace and keeps nulls as nulls."""
    stripped = Series.from_arrow(pa.array(data)).str.lstrip()
    assert stripped.to_pylist() == expected


@pytest.mark.parametrize(
    ["data", "expected"],
    [
        # No nulls: only the trailing whitespace is removed
        (["\ta\t", "\nb\n", "\vc\t", "\td ", "e"], ["\ta", "\nb", "\vc", "\td", "e"]),
        # With at least one null
        (["\ta\t", None, "\vc\t", "\td ", "e"], ["\ta", None, "\vc", "\td", "e"]),
        # With all nulls
        ([None, None, None, None], [None, None, None, None]),
    ],
)
def test_series_utf8_rstrip(data, expected) -> None:
    """``Series.str.rstrip`` strips trailing whitespace and keeps nulls as nulls."""
    stripped = Series.from_arrow(pa.array(data)).str.rstrip()
    assert stripped.to_pylist() == expected
10 changes: 10 additions & 0 deletions tests/table/utf8/test_lstrip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from __future__ import annotations

from daft.expressions import col
from daft.table import MicroPartition


def test_utf8_lstrip():
    """Evaluating ``str.lstrip`` on a table column strips leading whitespace and preserves nulls."""
    raw = ["\ta\t", None, "\nb\n", "\vc\t", "\td ", "e"]
    want = ["a\t", None, "b\n", "c\t", "d ", "e"]
    table = MicroPartition.from_pydict({"col": raw})
    stripped = table.eval_expression_list([col("col").str.lstrip()])
    assert stripped.to_pydict() == {"col": want}
10 changes: 10 additions & 0 deletions tests/table/utf8/test_rstrip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from __future__ import annotations

from daft.expressions import col
from daft.table import MicroPartition


def test_utf8_rstrip():
    """Evaluating ``str.rstrip`` on a table column strips trailing whitespace and preserves nulls."""
    raw = ["\ta\t", None, "\nb\n", "\vc\t", "\td ", "e"]
    want = ["\ta", None, "\nb", "\vc", "\td", "e"]
    table = MicroPartition.from_pydict({"col": raw})
    stripped = table.eval_expression_list([col("col").str.rstrip()])
    assert stripped.to_pydict() == {"col": want}
Loading