Skip to content

Commit

Permalink
[FEAT] Add str.reverse() function (#1957)
Browse files Browse the repository at this point in the history
* Added the `reverse` function to match
https://ibis-project.org/reference/expression-strings#ibis.expr.types.strings.StringValue.reverse
* Added tests showing example usage

Closes #1924
  • Loading branch information
nsalerni authored Feb 28, 2024
1 parent e6c4970 commit 64fd3e6
Show file tree
Hide file tree
Showing 13 changed files with 147 additions and 0 deletions.
2 changes: 2 additions & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -921,6 +921,7 @@ class PyExpr:
def utf8_upper(self) -> PyExpr: ...
def utf8_lstrip(self) -> PyExpr: ...
def utf8_rstrip(self) -> PyExpr: ...
def utf8_reverse(self) -> PyExpr: ...
def image_decode(self) -> PyExpr: ...
def image_encode(self, image_format: ImageFormat) -> PyExpr: ...
def image_resize(self, w: int, h: int) -> PyExpr: ...
Expand Down Expand Up @@ -1001,6 +1002,7 @@ class PySeries:
def utf8_upper(self) -> PySeries: ...
def utf8_lstrip(self) -> PySeries: ...
def utf8_rstrip(self) -> PySeries: ...
def utf8_reverse(self) -> PySeries: ...
def is_nan(self) -> PySeries: ...
def dt_date(self) -> PySeries: ...
def dt_day(self) -> PySeries: ...
Expand Down
11 changes: 11 additions & 0 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,17 @@ def rstrip(self) -> Expression:
"""
return Expression._from_pyexpr(self._expr.utf8_rstrip())

def reverse(self) -> Expression:
"""Reverse a UTF-8 string
Example:
>>> col("x").str.reverse()
Returns:
Expression: a String expression which is `self` reversed
"""
return Expression._from_pyexpr(self._expr.utf8_reverse())


class ExpressionListNamespace(ExpressionNamespace):
def join(self, delimiter: str | Expression) -> Expression:
Expand Down
4 changes: 4 additions & 0 deletions daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,10 @@ def rstrip(self) -> Series:
assert self._series is not None
return Series._from_pyseries(self._series.utf8_rstrip())

def reverse(self) -> Series:
assert self._series is not None
return Series._from_pyseries(self._series.utf8_reverse())


class SeriesDateNamespace(SeriesNamespace):
def date(self) -> Series:
Expand Down
1 change: 1 addition & 0 deletions docs/source/api_docs/expressions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ The following methods are available under the ``expr.str`` attribute.
Expression.str.upper
Expression.str.lstrip
Expression.str.rstrip
Expression.str.reverse

.. _api-expressions-temporal:

Expand Down
13 changes: 13 additions & 0 deletions src/daft-core/src/array/ops/utf8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,19 @@ impl Utf8Array {
Ok(Utf8Array::from((self.name(), Box::new(arrow_result))))
}

pub fn reverse(&self) -> DaftResult<Utf8Array> {
let self_arrow = self.as_arrow();
let arrow_result = self_arrow
.iter()
.map(|val| {
let v = val?;
Some(v.chars().rev().collect::<String>())
})
.collect::<arrow2::array::Utf8Array<i64>>()
.with_validity(self_arrow.validity().cloned());
Ok(Utf8Array::from((self.name(), Box::new(arrow_result))))
}

fn binary_broadcasted_compare<ScalarKernel>(
&self,
other: &Self,
Expand Down
4 changes: 4 additions & 0 deletions src/daft-core/src/python/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,10 @@ impl PySeries {
Ok(self.series.utf8_rstrip()?.into())
}

pub fn utf8_reverse(&self) -> PyResult<Self> {
Ok(self.series.utf8_reverse()?.into())
}

pub fn is_nan(&self) -> PyResult<Self> {
Ok(self.series.is_nan()?.into())
}
Expand Down
10 changes: 10 additions & 0 deletions src/daft-core/src/series/ops/utf8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,4 +90,14 @@ impl Series {
))),
}
}

pub fn utf8_reverse(&self) -> DaftResult<Series> {
match self.data_type() {
DataType::Utf8 => Ok(self.utf8()?.reverse()?.into_series()),
DataType::Null => Ok(self.clone()),
dt => Err(DaftError::TypeError(format!(
"Reverse not implemented for type {dt}"
))),
}
}
}
11 changes: 11 additions & 0 deletions src/daft-dsl/src/functions/utf8/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ mod endswith;
mod length;
mod lower;
mod lstrip;
mod reverse;
mod rstrip;
mod split;
mod startswith;
Expand All @@ -13,6 +14,7 @@ use endswith::EndswithEvaluator;
use length::LengthEvaluator;
use lower::LowerEvaluator;
use lstrip::LstripEvaluator;
use reverse::ReverseEvaluator;
use rstrip::RstripEvaluator;
use serde::{Deserialize, Serialize};
use split::SplitEvaluator;
Expand All @@ -34,6 +36,7 @@ pub enum Utf8Expr {
Upper,
Lstrip,
Rstrip,
Reverse,
}

impl Utf8Expr {
Expand All @@ -50,6 +53,7 @@ impl Utf8Expr {
Upper => &UpperEvaluator {},
Lstrip => &LstripEvaluator {},
Rstrip => &RstripEvaluator {},
Reverse => &ReverseEvaluator {},
}
}
}
Expand Down Expand Up @@ -116,3 +120,10 @@ pub fn rstrip(data: &Expr) -> Expr {
inputs: vec![data.clone()],
}
}

pub fn reverse(data: &Expr) -> Expr {
Expr::Function {
func: super::FunctionExpr::Utf8(Utf8Expr::Reverse),
inputs: vec![data.clone()],
}
}
46 changes: 46 additions & 0 deletions src/daft-dsl/src/functions/utf8/reverse.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
use daft_core::{
datatypes::{DataType, Field},
schema::Schema,
series::Series,
};

use crate::Expr;
use common_error::{DaftError, DaftResult};

use super::super::FunctionEvaluator;

pub(super) struct ReverseEvaluator {}

impl FunctionEvaluator for ReverseEvaluator {
fn fn_name(&self) -> &'static str {
"reverse"
}

fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult<Field> {
match inputs {
[data] => match data.to_field(schema) {
Ok(data_field) => match &data_field.dtype {
DataType::Utf8 => Ok(Field::new(data_field.name, DataType::Utf8)),
_ => Err(DaftError::TypeError(format!(
"Expects input to reverse to be utf8, but received {data_field}",
))),
},
Err(e) => Err(e),
},
_ => Err(DaftError::SchemaMismatch(format!(
"Expected 1 input args, got {}",
inputs.len()
))),
}
}

fn evaluate(&self, inputs: &[Series], _: &Expr) -> DaftResult<Series> {
match inputs {
[data] => data.utf8_reverse(),
_ => Err(DaftError::ValueError(format!(
"Expected 1 input args, got {}",
inputs.len()
))),
}
}
}
5 changes: 5 additions & 0 deletions src/daft-dsl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,11 @@ impl PyExpr {
Ok(rstrip(&self.expr).into())
}

pub fn utf8_reverse(&self) -> PyResult<Self> {
use crate::functions::utf8::reverse;
Ok(reverse(&self.expr).into())
}

pub fn image_decode(&self) -> PyResult<Self> {
use crate::functions::image::decode;
Ok(decode(&self.expr).into())
Expand Down
10 changes: 10 additions & 0 deletions tests/expressions/typing/test_str.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,13 @@ def test_str_rstrip():
run_kernel=s.str.rstrip,
resolvable=True,
)


def test_str_reverse():
s = Series.from_arrow(pa.array(["abc", "def", "ghi", None, ""]), name="arg")
assert_typing_resolve_vs_runtime_behavior(
data=[s],
expr=col(s.name()).str.reverse(),
run_kernel=s.str.reverse,
resolvable=True,
)
20 changes: 20 additions & 0 deletions tests/series/test_utf8_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,3 +294,23 @@ def test_series_utf8_rstrip(data, expected) -> None:
s = Series.from_arrow(pa.array(data))
result = s.str.rstrip()
assert result.to_pylist() == expected


@pytest.mark.parametrize(
["data", "expected"],
[
(["abc", "def", "ghi"], ["cba", "fed", "ihg"]),
# With at least one null
(["abc", None, "def", "ghi"], ["cba", None, "fed", "ihg"]),
# With all nulls
([None] * 4, [None] * 4),
# With emojis
(["😃😌😝", "abc😃😄😅"], ["😝😌😃", "😅😄😃cba"]),
# With non-latin alphabet
(["こんにちは", "こんにちはa", "こんにちはa😄😃"], ["はちにんこ", "aはちにんこ", "😃😄aはちにんこ"]),
],
)
def test_series_utf8_reverse(data, expected) -> None:
s = Series.from_arrow(pa.array(data))
result = s.str.reverse()
assert result.to_pylist() == expected
10 changes: 10 additions & 0 deletions tests/table/utf8/test_reverse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from __future__ import annotations

from daft.expressions import col
from daft.table import MicroPartition


def test_utf8_reverse():
table = MicroPartition.from_pydict({"col": ["abc", None, "def", "ghi"]})
result = table.eval_expression_list([col("col").str.reverse()])
assert result.to_pydict() == {"col": ["cba", None, "fed", "ihg"]}

0 comments on commit 64fd3e6

Please sign in to comment.