From 38f0dc752556881d572cc788e1e44b7be037fd09 Mon Sep 17 00:00:00 2001 From: Nupur Agrawal Date: Thu, 21 Mar 2024 13:24:57 +0530 Subject: [PATCH] add support to handle all types of integer --- src/daft-core/src/array/ops/utf8.rs | 58 ++++++++++++++++-------- src/daft-core/src/series/ops/downcast.rs | 4 ++ src/daft-core/src/series/ops/utf8.rs | 35 ++++++++++++-- src/daft-dsl/src/functions/utf8/left.rs | 12 ++--- tests/series/test_utf8_ops.py | 16 ++++++- tests/table/utf8/test_left.py | 10 ++++ 6 files changed, 103 insertions(+), 32 deletions(-) create mode 100644 tests/table/utf8/test_left.py diff --git a/src/daft-core/src/array/ops/utf8.rs b/src/daft-core/src/array/ops/utf8.rs index 54d6b4eec5..75ac685c32 100644 --- a/src/daft-core/src/array/ops/utf8.rs +++ b/src/daft-core/src/array/ops/utf8.rs @@ -1,11 +1,12 @@ use crate::{ - array::ListArray, - datatypes::{BooleanArray, Field, UInt32Array, UInt64Array, Utf8Array}, + array::{DataArray, ListArray}, + datatypes::{BooleanArray, DaftIntegerType, DaftNumericType, Field, UInt64Array, Utf8Array}, DataType, Series, }; -use arrow2::{self}; +use arrow2::{self, array::Array}; use common_error::{DaftError, DaftResult}; +use num_traits::NumCast; use super::{as_arrow::AsArrow, full::FullNull}; @@ -260,26 +261,35 @@ impl Utf8Array { Ok(Utf8Array::from((self.name(), Box::new(arrow_result)))) } - pub fn left(&self, n: &UInt32Array) -> DaftResult { + pub fn left(&self, n: &DataArray) -> DaftResult + where + I: DaftIntegerType, + ::Native: Ord, + { let self_arrow = self.as_arrow(); let n_arrow = n.as_arrow(); // Handle empty cases. - if self.is_empty() || n.is_empty() { + if self.is_empty() || n_arrow.is_empty() { return Ok(Utf8Array::empty(self.name(), self.data_type())); } - match (self.len(), n.len()) { + match (self.len(), n_arrow.len()) { // Matching len case: (self_len, n_len) if self_len == n_len => { let arrow_result = self_arrow .iter() .zip(n_arrow.iter()) .map(|(val, n)| match (val, n) { - (Some(val), Some(pat)) => { - Some(val.chars().take(*pat as usize).collect::()) + (Some(val), Some(nchar)) => { + let nchar: usize = NumCast::from(*nchar).ok_or_else(|| { + DaftError::ComputeError(format!( + "failed to cast rhs as usize {nchar}" + )) + })?; + Ok(Some(val.chars().take(nchar).collect::())) } - _ => None, + _ => Ok(None), }) - .collect::>(); + .collect::>>()?; Ok(Utf8Array::from((self.name(), Box::new(arrow_result)))) } @@ -293,11 +303,17 @@ impl Utf8Array { self_len, )), Some(n_scalar_value) => { + let n_scalar_value: usize = + NumCast::from(n_scalar_value).ok_or_else(|| { + DaftError::ComputeError(format!( + "failed to cast rhs as usize {n_scalar_value}" + )) + })?; let arrow_result = self_arrow .iter() .map(|val| { let v = val?; - Some(v.chars().take(n_scalar_value as usize).collect::()) + Some(v.chars().take(n_scalar_value).collect::()) }) .collect::>(); @@ -313,16 +329,18 @@ impl Utf8Array { Some(self_scalar_value) => { let arrow_result = n_arrow .iter() - .map(|n| { - let n = n?; - Some( - self_scalar_value - .chars() - .take(*n as usize) - .collect::(), - ) + .map(|n| match n { + None => Ok(None), + Some(n) => { + let n: usize = NumCast::from(*n).ok_or_else(|| { + DaftError::ComputeError(format!( + "failed to cast rhs as usize {n}" + )) + })?; + Ok(Some(self_scalar_value.chars().take(n).collect::())) + } }) - .collect::>(); + .collect::>>()?; Ok(Utf8Array::from((self.name(), Box::new(arrow_result)))) } diff --git a/src/daft-core/src/series/ops/downcast.rs b/src/daft-core/src/series/ops/downcast.rs index a9c5780fee..cceb874787 100644 --- a/src/daft-core/src/series/ops/downcast.rs +++ b/src/daft-core/src/series/ops/downcast.rs @@ -45,6 +45,10 @@ impl Series { self.downcast() } + pub fn i128(&self) -> DaftResult<&UInt64Array> { + self.downcast() + } + pub fn u8(&self) -> DaftResult<&UInt8Array> { self.downcast() } diff --git a/src/daft-core/src/series/ops/utf8.rs b/src/daft-core/src/series/ops/utf8.rs index 48bc55710e..13dee4c065 100644 --- a/src/daft-core/src/series/ops/utf8.rs +++ b/src/daft-core/src/series/ops/utf8.rs @@ -120,11 +120,36 @@ impl Series { } } - pub fn utf8_left(&self, pattern: &Series) -> DaftResult { - match self.data_type() { - DataType::Utf8 => Ok(self.utf8()?.left(pattern.u32()?)?.into_series()), - DataType::Null => Ok(self.clone()), - dt => Err(DaftError::TypeError(format!( + pub fn utf8_left(&self, nchars: &Series) -> DaftResult { + match (self.data_type(), nchars.data_type()) { + (DataType::Utf8, DataType::UInt32) => { + Ok(self.utf8()?.left(nchars.u32()?)?.into_series()) + } + (DataType::Utf8, DataType::Int32) => { + Ok(self.utf8()?.left(nchars.i32()?)?.into_series()) + } + (DataType::Utf8, DataType::UInt64) => { + Ok(self.utf8()?.left(nchars.u64()?)?.into_series()) + } + (DataType::Utf8, DataType::Int64) => { + Ok(self.utf8()?.left(nchars.i64()?)?.into_series()) + } + (DataType::Utf8, DataType::Int8) => Ok(self.utf8()?.left(nchars.i8()?)?.into_series()), + (DataType::Utf8, DataType::UInt8) => Ok(self.utf8()?.left(nchars.u8()?)?.into_series()), + (DataType::Utf8, DataType::Int16) => { + Ok(self.utf8()?.left(nchars.i16()?)?.into_series()) + } + (DataType::Utf8, DataType::UInt16) => { + Ok(self.utf8()?.left(nchars.u16()?)?.into_series()) + } + (DataType::Utf8, DataType::Int128) => { + Ok(self.utf8()?.left(nchars.i128()?)?.into_series()) + } + (DataType::Null, dt) if dt.is_integer() => Ok(self.clone()), + (DataType::Utf8, dt) => Err(DaftError::TypeError(format!( + "Left not implemented for nchar type {dt}" + ))), + (dt, _) => Err(DaftError::TypeError(format!( "Left not implemented for type {dt}" ))), } diff --git a/src/daft-dsl/src/functions/utf8/left.rs b/src/daft-dsl/src/functions/utf8/left.rs index 4dcb3cef8a..4abda20bda 100644 --- a/src/daft-dsl/src/functions/utf8/left.rs +++ b/src/daft-dsl/src/functions/utf8/left.rs @@ -18,14 +18,14 @@ impl FunctionEvaluator for LeftEvaluator { fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult { match inputs { - [data, pattern] => match (data.to_field(schema), pattern.to_field(schema)) { - (Ok(data_field), Ok(pattern_field)) => { - match (&data_field.dtype, &pattern_field.dtype) { - (DataType::Utf8, DataType::UInt32) => { + [data, nchars] => match (data.to_field(schema), nchars.to_field(schema)) { + (Ok(data_field), Ok(nchars_field)) => { + match (&data_field.dtype, &nchars_field.dtype) { + (DataType::Utf8, dt) if dt.is_integer() => { Ok(Field::new(data_field.name, DataType::Utf8)) } _ => Err(DaftError::TypeError(format!( - "Expects inputs to left to be utf8 and uint32, but received {data_field} and {pattern_field}", + "Expects inputs to left to be utf8 and uint32, but received {data_field} and {nchars_field}", ))), } } @@ -40,7 +40,7 @@ impl FunctionEvaluator for LeftEvaluator { fn evaluate(&self, inputs: &[Series], _: &Expr) -> DaftResult { match inputs { - [data, pattern] => data.utf8_left(pattern), + [data, nchars] => data.utf8_left(nchars), _ => Err(DaftError::ValueError(format!( "Expected 2 input args, got {}", inputs.len() diff --git a/tests/series/test_utf8_ops.py b/tests/series/test_utf8_ops.py index 45eb308238..dcab960882 100644 --- a/tests/series/test_utf8_ops.py +++ b/tests/series/test_utf8_ops.py @@ -414,7 +414,21 @@ def test_series_utf8_left_mismatch_len() -> None: s.str.left(nchars) -def test_series_utf8_left_bad_pattern() -> None: +def test_series_utf8_left_bad_nchars() -> None: s = Series.from_arrow(pa.array(["foo", "barbaz", "quux"])) with pytest.raises(ValueError): s.str.left(1) + + +def test_series_utf8_left_bad_nchars_dtype() -> None: + s = Series.from_arrow(pa.array(["foo", "barbaz", "quux"])) + nchars = Series.from_arrow(pa.array(["1", "2", "3"])) + with pytest.raises(ValueError): + s.str.left(nchars) + + +def test_series_utf8_left_bad_dtype() -> None: + s = Series.from_arrow(pa.array([1, 2, 3])) + nchars = Series.from_arrow(pa.array([1, 2, 3])) + with pytest.raises(ValueError): + s.str.left(nchars) diff --git a/tests/table/utf8/test_left.py b/tests/table/utf8/test_left.py new file mode 100644 index 0000000000..764ec24b77 --- /dev/null +++ b/tests/table/utf8/test_left.py @@ -0,0 +1,10 @@ +from __future__ import annotations + +from daft.expressions import col +from daft.table import MicroPartition + + +def test_utf8_left(): + table = MicroPartition.from_pydict({"col": ["foo", None, "barBaz", "quux", "1"]}) + result = table.eval_expression_list([col("col").str.left(3)]) + assert result.to_pydict() == {"col": ["foo", None, "bar", "quu", "1"]}