Skip to content

Commit

Permalink
add support to handle all types of integer
Browse files Browse the repository at this point in the history
  • Loading branch information
murex971 committed Mar 21, 2024
1 parent 9252174 commit 38f0dc7
Show file tree
Hide file tree
Showing 6 changed files with 103 additions and 32 deletions.
58 changes: 38 additions & 20 deletions src/daft-core/src/array/ops/utf8.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use crate::{
array::ListArray,
datatypes::{BooleanArray, Field, UInt32Array, UInt64Array, Utf8Array},
array::{DataArray, ListArray},
datatypes::{BooleanArray, DaftIntegerType, DaftNumericType, Field, UInt64Array, Utf8Array},
DataType, Series,
};
use arrow2::{self};
use arrow2::{self, array::Array};

use common_error::{DaftError, DaftResult};
use num_traits::NumCast;

use super::{as_arrow::AsArrow, full::FullNull};

Expand Down Expand Up @@ -260,26 +261,35 @@ impl Utf8Array {
Ok(Utf8Array::from((self.name(), Box::new(arrow_result))))
}

pub fn left(&self, n: &UInt32Array) -> DaftResult<Utf8Array> {
pub fn left<I>(&self, n: &DataArray<I>) -> DaftResult<Utf8Array>
where
I: DaftIntegerType,
<I as DaftNumericType>::Native: Ord,
{
let self_arrow = self.as_arrow();
let n_arrow = n.as_arrow();
// Handle empty cases.
if self.is_empty() || n.is_empty() {
if self.is_empty() || n_arrow.is_empty() {
return Ok(Utf8Array::empty(self.name(), self.data_type()));
}
match (self.len(), n.len()) {
match (self.len(), n_arrow.len()) {
// Matching len case:
(self_len, n_len) if self_len == n_len => {
let arrow_result = self_arrow
.iter()
.zip(n_arrow.iter())
.map(|(val, n)| match (val, n) {
(Some(val), Some(pat)) => {
Some(val.chars().take(*pat as usize).collect::<String>())
(Some(val), Some(nchar)) => {
let nchar: usize = NumCast::from(*nchar).ok_or_else(|| {
DaftError::ComputeError(format!(
"failed to cast rhs as usize {nchar}"
))
})?;
Ok(Some(val.chars().take(nchar).collect::<String>()))
}
_ => None,
_ => Ok(None),
})
.collect::<arrow2::array::Utf8Array<i64>>();
.collect::<DaftResult<arrow2::array::Utf8Array<i64>>>()?;

Ok(Utf8Array::from((self.name(), Box::new(arrow_result))))
}
Expand All @@ -293,11 +303,17 @@ impl Utf8Array {
self_len,
)),
Some(n_scalar_value) => {
let n_scalar_value: usize =
NumCast::from(n_scalar_value).ok_or_else(|| {
DaftError::ComputeError(format!(
"failed to cast rhs as usize {n_scalar_value}"
))
})?;
let arrow_result = self_arrow
.iter()
.map(|val| {
let v = val?;
Some(v.chars().take(n_scalar_value as usize).collect::<String>())
Some(v.chars().take(n_scalar_value).collect::<String>())
})
.collect::<arrow2::array::Utf8Array<i64>>();

Expand All @@ -313,16 +329,18 @@ impl Utf8Array {
Some(self_scalar_value) => {
let arrow_result = n_arrow
.iter()
.map(|n| {
let n = n?;
Some(
self_scalar_value
.chars()
.take(*n as usize)
.collect::<String>(),
)
.map(|n| match n {
None => Ok(None),
Some(n) => {
let n: usize = NumCast::from(*n).ok_or_else(|| {
DaftError::ComputeError(format!(
"failed to cast rhs as usize {n}"
))
})?;
Ok(Some(self_scalar_value.chars().take(n).collect::<String>()))
}
})
.collect::<arrow2::array::Utf8Array<i64>>();
.collect::<DaftResult<arrow2::array::Utf8Array<i64>>>()?;

Ok(Utf8Array::from((self.name(), Box::new(arrow_result))))
}
Expand Down
4 changes: 4 additions & 0 deletions src/daft-core/src/series/ops/downcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ impl Series {
self.downcast()
}

/// Downcasts this series by delegating to `self.downcast()` and returns the
/// borrowed array; any error from `downcast` is returned unchanged.
///
/// NOTE(review): the method is named `i128`, yet the declared return type is
/// `&UInt64Array` rather than a 128-bit signed array type. This looks like a
/// name/type mismatch — confirm whether an Int128 array type was intended
/// before relying on this accessor for Int128 data.
pub fn i128(&self) -> DaftResult<&UInt64Array> {
self.downcast()
}

/// Borrows this series as a `UInt8Array` by delegating to `self.downcast()`.
///
/// Any error produced by `downcast` is returned to the caller unchanged.
pub fn u8(&self) -> DaftResult<&UInt8Array> {
self.downcast()
}
Expand Down
35 changes: 30 additions & 5 deletions src/daft-core/src/series/ops/utf8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,36 @@ impl Series {
}
}

pub fn utf8_left(&self, pattern: &Series) -> DaftResult<Series> {
match self.data_type() {
DataType::Utf8 => Ok(self.utf8()?.left(pattern.u32()?)?.into_series()),
DataType::Null => Ok(self.clone()),
dt => Err(DaftError::TypeError(format!(
pub fn utf8_left(&self, nchars: &Series) -> DaftResult<Series> {
match (self.data_type(), nchars.data_type()) {
(DataType::Utf8, DataType::UInt32) => {
Ok(self.utf8()?.left(nchars.u32()?)?.into_series())
}
(DataType::Utf8, DataType::Int32) => {
Ok(self.utf8()?.left(nchars.i32()?)?.into_series())
}
(DataType::Utf8, DataType::UInt64) => {
Ok(self.utf8()?.left(nchars.u64()?)?.into_series())
}
(DataType::Utf8, DataType::Int64) => {
Ok(self.utf8()?.left(nchars.i64()?)?.into_series())
}
(DataType::Utf8, DataType::Int8) => Ok(self.utf8()?.left(nchars.i8()?)?.into_series()),
(DataType::Utf8, DataType::UInt8) => Ok(self.utf8()?.left(nchars.u8()?)?.into_series()),
(DataType::Utf8, DataType::Int16) => {
Ok(self.utf8()?.left(nchars.i16()?)?.into_series())
}
(DataType::Utf8, DataType::UInt16) => {
Ok(self.utf8()?.left(nchars.u16()?)?.into_series())
}
(DataType::Utf8, DataType::Int128) => {
Ok(self.utf8()?.left(nchars.i128()?)?.into_series())
}
(DataType::Null, dt) if dt.is_integer() => Ok(self.clone()),
(DataType::Utf8, dt) => Err(DaftError::TypeError(format!(
"Left not implemented for nchar type {dt}"
))),
(dt, _) => Err(DaftError::TypeError(format!(
"Left not implemented for type {dt}"
))),
}
Expand Down
12 changes: 6 additions & 6 deletions src/daft-dsl/src/functions/utf8/left.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ impl FunctionEvaluator for LeftEvaluator {

fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult<Field> {
match inputs {
[data, pattern] => match (data.to_field(schema), pattern.to_field(schema)) {
(Ok(data_field), Ok(pattern_field)) => {
match (&data_field.dtype, &pattern_field.dtype) {
(DataType::Utf8, DataType::UInt32) => {
[data, nchars] => match (data.to_field(schema), nchars.to_field(schema)) {
(Ok(data_field), Ok(nchars_field)) => {
match (&data_field.dtype, &nchars_field.dtype) {
(DataType::Utf8, dt) if dt.is_integer() => {
Ok(Field::new(data_field.name, DataType::Utf8))
}
_ => Err(DaftError::TypeError(format!(
"Expects inputs to left to be utf8 and uint32, but received {data_field} and {pattern_field}",
"Expects inputs to left to be utf8 and uint32, but received {data_field} and {nchars_field}",
))),
}
}
Expand All @@ -40,7 +40,7 @@ impl FunctionEvaluator for LeftEvaluator {

fn evaluate(&self, inputs: &[Series], _: &Expr) -> DaftResult<Series> {
match inputs {
[data, pattern] => data.utf8_left(pattern),
[data, nchars] => data.utf8_left(nchars),
_ => Err(DaftError::ValueError(format!(
"Expected 2 input args, got {}",
inputs.len()
Expand Down
16 changes: 15 additions & 1 deletion tests/series/test_utf8_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,21 @@ def test_series_utf8_left_mismatch_len() -> None:
s.str.left(nchars)


def test_series_utf8_left_bad_pattern() -> None:
def test_series_utf8_left_bad_nchars() -> None:
    """Passing a bare Python int as ``nchars`` must raise ``ValueError``."""
    strings = Series.from_arrow(pa.array(["foo", "barbaz", "quux"]))
    with pytest.raises(ValueError):
        strings.str.left(1)


def test_series_utf8_left_bad_nchars_dtype() -> None:
    """A string-typed ``nchars`` series must be rejected with ``ValueError``."""
    strings = Series.from_arrow(pa.array(["foo", "barbaz", "quux"]))
    counts_as_text = Series.from_arrow(pa.array(["1", "2", "3"]))
    with pytest.raises(ValueError):
        strings.str.left(counts_as_text)


def test_series_utf8_left_bad_dtype() -> None:
    """Calling ``str.left`` on an integer series must raise ``ValueError``."""
    not_strings = Series.from_arrow(pa.array([1, 2, 3]))
    counts = Series.from_arrow(pa.array([1, 2, 3]))
    with pytest.raises(ValueError):
        not_strings.str.left(counts)
10 changes: 10 additions & 0 deletions tests/table/utf8/test_left.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from __future__ import annotations

from daft.expressions import col
from daft.table import MicroPartition


def test_utf8_left():
    """``str.left(3)`` keeps at most three leading characters and passes nulls through."""
    source = {"col": ["foo", None, "barBaz", "quux", "1"]}
    expected = {"col": ["foo", None, "bar", "quu", "1"]}
    table = MicroPartition.from_pydict(source)
    projected = table.eval_expression_list([col("col").str.left(3)])
    assert projected.to_pydict() == expected

0 comments on commit 38f0dc7

Please sign in to comment.