diff --git a/daft/daft.pyi b/daft/daft.pyi index e52274f087..4bc03af0e3 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -857,6 +857,7 @@ class PyExpr: def image_crop(self, bbox: PyExpr) -> PyExpr: ... def list_join(self, delimiter: PyExpr) -> PyExpr: ... def list_lengths(self) -> PyExpr: ... + def list_get(self, idx: PyExpr) -> PyExpr: ... def url_download( self, max_connections: int, raise_error_on_failure: bool, multi_thread: bool, config: IOConfig ) -> PyExpr: ... @@ -928,6 +929,7 @@ class PySeries: def partitioning_months(self) -> PySeries: ... def partitioning_years(self) -> PySeries: ... def list_lengths(self) -> PySeries: ... + def list_get(self, idx: PySeries) -> PySeries: ... def image_decode(self) -> PySeries: ... def image_encode(self, image_format: ImageFormat) -> PySeries: ... def image_resize(self, w: int, h: int) -> PySeries: ... diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index a5e9baebf0..eb47eed44b 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -671,6 +671,19 @@ def lengths(self) -> Expression: """ return Expression._from_pyexpr(self._expr.list_lengths()) + def get(self, idx: int | Expression, default: object = None) -> Expression: + """Gets the element at an index in each list + + Args: + idx: index or indices to retrieve from each list + default: the default value if the specified index is out of bounds + + Returns: + Expression: an expression with the type of the list values + """ + idx_expr = Expression._to_expression(idx) + return Expression._from_pyexpr(self._expr.list_get(idx_expr._expr)) + class ExpressionsProjection(Iterable[Expression]): """A collection of Expressions that can be projected onto a Table to produce another Table diff --git a/daft/series.py b/daft/series.py index 807f23bbe2..ee0af5ca68 100644 --- a/daft/series.py +++ b/daft/series.py @@ -600,6 +600,9 @@ class SeriesListNamespace(SeriesNamespace): def lengths(self) -> Series: return Series._from_pyseries(self._series.list_lengths()) + def get(self, idx: Series, default: Series) -> Series: + return Series._from_pyseries(self._series.list_get(idx._series)) + class SeriesImageNamespace(SeriesNamespace): def decode(self) -> Series: diff --git a/src/daft-core/src/array/ops/list.rs b/src/daft-core/src/array/ops/list.rs index f6e37c6fea..a7ca235685 100644 --- a/src/daft-core/src/array/ops/list.rs +++ b/src/daft-core/src/array/ops/list.rs @@ -4,7 +4,7 @@ use crate::array::{ growable::{make_growable, Growable}, FixedSizeListArray, ListArray, }; -use crate::datatypes::{UInt64Array, Utf8Array}; +use crate::datatypes::{Int64Array, UInt64Array, Utf8Array}; use crate::DataType; use crate::series::Series; @@ -112,6 +112,55 @@ impl ListArray { Box::new(arrow2::array::Utf8Array::from_iter(result)), ))) } + + fn get_children_helper( + &self, + idx_iter: &mut impl Iterator>, + ) -> DaftResult { + let mut growable = make_growable( + self.name(), + self.child_data_type(), + vec![&self.flat_child], + true, + self.len(), + ); + + let offsets = self.offsets(); + + for i in 0..self.len() { + let is_valid = self.is_valid(i); + let start = *offsets.get(i).unwrap(); + let end = *offsets.get(i + 1).unwrap(); + let child_idx = idx_iter.next().unwrap().unwrap(); + + // only add index value when list is valid and index is within bounds + match (is_valid, child_idx >= 0, start + child_idx, end + child_idx) { + (true, true, idx_offset, _) if idx_offset < end => { + growable.extend(0, idx_offset as usize, 1) + } + (true, false, _, idx_offset) if idx_offset >= start => { + growable.extend(0, idx_offset as usize, 1) + } + _ => growable.add_nulls(1), + }; + } + + growable.build() + } + + pub fn get_children(&self, idx: &Int64Array) -> DaftResult { + match idx.len() { + 1 => { + let mut idx_iter = repeat(idx.get(0)).take(self.len()); + self.get_children_helper(&mut idx_iter) + } + _ => { + assert_eq!(idx.len(), self.len()); + let mut idx_iter = idx.as_arrow().iter().map(|x| x.copied()); + self.get_children_helper(&mut idx_iter) + } + } + } } impl FixedSizeListArray { @@ -190,4 +239,49 @@ impl FixedSizeListArray { Box::new(arrow2::array::Utf8Array::from_iter(result)), ))) } + + fn get_children_helper( + &self, + idx_iter: &mut impl Iterator>, + ) -> DaftResult { + let mut growable = make_growable( + self.name(), + self.child_data_type(), + vec![&self.flat_child], + true, + self.len(), + ); + + let list_size = self.fixed_element_len(); + + for i in 0..self.len() { + let is_valid = self.is_valid(i); + let child_idx = idx_iter.next().unwrap().unwrap(); + + // only add index value when list is valid and index is within bounds + match (is_valid, child_idx.abs() < list_size as i64, child_idx >= 0) { + (true, true, true) => growable.extend(0, i * list_size + child_idx as usize, 1), + (true, true, false) => { + growable.extend(0, (i + 1) * list_size + child_idx as usize, 1) + } + _ => growable.add_nulls(1), + }; + } + + growable.build() + } + + pub fn get_children(&self, idx: &Int64Array) -> DaftResult { + match idx.len() { + 1 => { + let mut idx_iter = repeat(idx.get(0)).take(self.len()); + self.get_children_helper(&mut idx_iter) + } + _ => { + assert_eq!(idx.len(), self.len()); + let mut idx_iter = idx.as_arrow().iter().map(|x| x.copied()); + self.get_children_helper(&mut idx_iter) + } + } + } } diff --git a/src/daft-core/src/python/series.rs b/src/daft-core/src/python/series.rs index ae63e2832f..665f1c858f 100644 --- a/src/daft-core/src/python/series.rs +++ b/src/daft-core/src/python/series.rs @@ -308,6 +308,10 @@ impl PySeries { Ok(self.series.list_lengths()?.into_series().into()) } + pub fn list_get(&self, idx: &Self) -> PyResult { + Ok(self.series.list_get(&idx.series)?.into()) + } + pub fn image_decode(&self) -> PyResult { Ok(self.series.image_decode()?.into()) } diff --git a/src/daft-core/src/series/ops/list.rs b/src/daft-core/src/series/ops/list.rs index 73972aa34b..ba5c85418f 100644 --- a/src/daft-core/src/series/ops/list.rs +++ b/src/daft-core/src/series/ops/list.rs @@ -53,4 +53,20 @@ impl Series { ))), } } + + pub fn list_get(&self, idx: &Series) -> DaftResult { + use DataType::*; + + let idx = idx.cast(&Int64)?; + let idx_arr = idx.i64().unwrap(); + + match self.data_type() { + List(_) => self.list()?.get_children(idx_arr), + FixedSizeList(..) => self.fixed_size_list()?.get_children(idx_arr), + dt => Err(DaftError::TypeError(format!( + "Get not implemented for {}", + dt + ))), + } + } } diff --git a/src/daft-dsl/src/functions/list/get.rs b/src/daft-dsl/src/functions/list/get.rs new file mode 100644 index 0000000000..688acdd6d9 --- /dev/null +++ b/src/daft-dsl/src/functions/list/get.rs @@ -0,0 +1,47 @@ +use crate::Expr; +use daft_core::{datatypes::Field, schema::Schema, series::Series}; + +use common_error::{DaftError, DaftResult}; + +use super::super::FunctionEvaluator; + +pub(super) struct GetEvaluator {} + +impl FunctionEvaluator for GetEvaluator { + fn fn_name(&self) -> &'static str { + "get" + } + + fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult { + match inputs { + [input, idx] => { + let input_field = input.to_field(schema)?; + let idx_field = idx.to_field(schema)?; + + if !idx_field.dtype.is_integer() { + return Err(DaftError::TypeError(format!( + "Expected get index to be integer, received: {}", + idx_field.dtype + ))); + } + + let exploded_field = input_field.to_exploded_field()?; + Ok(exploded_field) + } + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 2 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series], _: &Expr) -> DaftResult { + match inputs { + [input, idx] => Ok(input.list_get(idx)?), + _ => Err(DaftError::ValueError(format!( + "Expected 2 input args, got {}", + inputs.len() + ))), + } + } +} diff --git a/src/daft-dsl/src/functions/list/mod.rs b/src/daft-dsl/src/functions/list/mod.rs index 45340cc04f..0693c59277 100644 --- a/src/daft-dsl/src/functions/list/mod.rs +++ b/src/daft-dsl/src/functions/list/mod.rs @@ -1,8 +1,10 @@ mod explode; +mod get; mod join; mod lengths; use explode::ExplodeEvaluator; +use get::GetEvaluator; use join::JoinEvaluator; use lengths::LengthsEvaluator; use serde::{Deserialize, Serialize}; @@ -16,6 +18,7 @@ pub enum ListExpr { Explode, Join, Lengths, + Get, } impl ListExpr { @@ -26,6 +29,7 @@ impl ListExpr { Explode => &ExplodeEvaluator {}, Join => &JoinEvaluator {}, Lengths => &LengthsEvaluator {}, + Get => &GetEvaluator {}, } } } @@ -50,3 +54,10 @@ pub fn lengths(input: &Expr) -> Expr { inputs: vec![input.clone()], } } + +pub fn get(input: &Expr, idx: &Expr) -> Expr { + Expr::Function { + func: super::FunctionExpr::List(ListExpr::Get), + inputs: vec![input.clone(), idx.clone()], + } +} diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index e43fc1be4a..1859ce51e4 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -353,6 +353,11 @@ impl PyExpr { Ok(lengths(&self.expr).into()) } + pub fn list_get(&self, idx: &Self) -> PyResult { + use crate::functions::list::get; + Ok(get(&self.expr, &idx.expr).into()) + } + pub fn url_download( &self, max_connections: i64,