Skip to content

Commit

Permalink
add list getter (no overrideable default)
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinzwang committed Jan 10, 2024
1 parent a994c7b commit a7b5e5e
Show file tree
Hide file tree
Showing 9 changed files with 196 additions and 1 deletion.
2 changes: 2 additions & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -857,6 +857,7 @@ class PyExpr:
def image_crop(self, bbox: PyExpr) -> PyExpr: ...
def list_join(self, delimiter: PyExpr) -> PyExpr: ...
def list_lengths(self) -> PyExpr: ...
def list_get(self, idx: PyExpr) -> PyExpr: ...
def url_download(
self, max_connections: int, raise_error_on_failure: bool, multi_thread: bool, config: IOConfig
) -> PyExpr: ...
Expand Down Expand Up @@ -928,6 +929,7 @@ class PySeries:
def partitioning_months(self) -> PySeries: ...
def partitioning_years(self) -> PySeries: ...
def list_lengths(self) -> PySeries: ...
def list_get(self, idx: PySeries) -> PySeries: ...
def image_decode(self) -> PySeries: ...
def image_encode(self, image_format: ImageFormat) -> PySeries: ...
def image_resize(self, w: int, h: int) -> PySeries: ...
Expand Down
13 changes: 13 additions & 0 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,19 @@ def lengths(self) -> Expression:
"""
return Expression._from_pyexpr(self._expr.list_lengths())

def get(self, idx: int | Expression, default: object = None) -> Expression:
"""Gets the element at an index in each list
Args:
idx: index or indices to retrieve from each list
default: the default value if the specified index is out of bounds
Returns:
Expression: an expression with the type of the list values
"""
idx_expr = Expression._to_expression(idx)
return Expression._from_pyexpr(self._expr.list_get(idx_expr._expr))

Check warning on line 685 in daft/expressions/expressions.py

View check run for this annotation

Codecov / codecov/patch

daft/expressions/expressions.py#L684-L685

Added lines #L684 - L685 were not covered by tests


class ExpressionsProjection(Iterable[Expression]):
"""A collection of Expressions that can be projected onto a Table to produce another Table
Expand Down
3 changes: 3 additions & 0 deletions daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,9 @@ class SeriesListNamespace(SeriesNamespace):
def lengths(self) -> Series:
return Series._from_pyseries(self._series.list_lengths())

def get(self, idx: Series, default: Series) -> Series:
return Series._from_pyseries(self._series.list_get(idx._series))

Check warning on line 604 in daft/series.py

View check run for this annotation

Codecov / codecov/patch

daft/series.py#L604

Added line #L604 was not covered by tests


class SeriesImageNamespace(SeriesNamespace):
def decode(self) -> Series:
Expand Down
96 changes: 95 additions & 1 deletion src/daft-core/src/array/ops/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::array::{
growable::{make_growable, Growable},
FixedSizeListArray, ListArray,
};
use crate::datatypes::{UInt64Array, Utf8Array};
use crate::datatypes::{Int64Array, UInt64Array, Utf8Array};
use crate::DataType;

use crate::series::Series;
Expand Down Expand Up @@ -112,6 +112,55 @@ impl ListArray {
Box::new(arrow2::array::Utf8Array::from_iter(result)),
)))
}

fn get_children_helper(
&self,
idx_iter: &mut impl Iterator<Item = Option<i64>>,
) -> DaftResult<Series> {
let mut growable = make_growable(
self.name(),
self.child_data_type(),
vec![&self.flat_child],
true,
self.len(),
);

let offsets = self.offsets();

for i in 0..self.len() {
let is_valid = self.is_valid(i);
let start = *offsets.get(i).unwrap();
let end = *offsets.get(i + 1).unwrap();
let child_idx = idx_iter.next().unwrap().unwrap();

// only add index value when list is valid and index is within bounds
match (is_valid, child_idx >= 0, start + child_idx, end + child_idx) {
(true, true, idx_offset, _) if idx_offset < end => {
growable.extend(0, idx_offset as usize, 1)
}
(true, false, _, idx_offset) if idx_offset >= start => {
growable.extend(0, idx_offset as usize, 1)
}
_ => growable.add_nulls(1),
};
}

growable.build()
}

pub fn get_children(&self, idx: &Int64Array) -> DaftResult<Series> {
match idx.len() {
1 => {
let mut idx_iter = repeat(idx.get(0)).take(self.len());
self.get_children_helper(&mut idx_iter)
}
_ => {
assert_eq!(idx.len(), self.len());
let mut idx_iter = idx.as_arrow().iter().map(|x| x.copied());
self.get_children_helper(&mut idx_iter)
}
}
}
}

impl FixedSizeListArray {
Expand Down Expand Up @@ -190,4 +239,49 @@ impl FixedSizeListArray {
Box::new(arrow2::array::Utf8Array::from_iter(result)),
)))
}

fn get_children_helper(
&self,
idx_iter: &mut impl Iterator<Item = Option<i64>>,
) -> DaftResult<Series> {
let mut growable = make_growable(
self.name(),
self.child_data_type(),
vec![&self.flat_child],
true,
self.len(),
);

let list_size = self.fixed_element_len();

for i in 0..self.len() {
let is_valid = self.is_valid(i);
let child_idx = idx_iter.next().unwrap().unwrap();

// only add index value when list is valid and index is within bounds
match (is_valid, child_idx.abs() < list_size as i64, child_idx >= 0) {
(true, true, true) => growable.extend(0, i * list_size + child_idx as usize, 1),
(true, true, false) => {
growable.extend(0, (i + 1) * list_size + child_idx as usize, 1)
}
_ => growable.add_nulls(1),
};
}

growable.build()
}

pub fn get_children(&self, idx: &Int64Array) -> DaftResult<Series> {
match idx.len() {
1 => {
let mut idx_iter = repeat(idx.get(0)).take(self.len());
self.get_children_helper(&mut idx_iter)
}
_ => {
assert_eq!(idx.len(), self.len());
let mut idx_iter = idx.as_arrow().iter().map(|x| x.copied());
self.get_children_helper(&mut idx_iter)
}
}
}
}
4 changes: 4 additions & 0 deletions src/daft-core/src/python/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,10 @@ impl PySeries {
Ok(self.series.list_lengths()?.into_series().into())
}

pub fn list_get(&self, idx: &Self) -> PyResult<Self> {
Ok(self.series.list_get(&idx.series)?.into())
}

pub fn image_decode(&self) -> PyResult<Self> {
Ok(self.series.image_decode()?.into())
}
Expand Down
16 changes: 16 additions & 0 deletions src/daft-core/src/series/ops/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,20 @@ impl Series {
))),
}
}

pub fn list_get(&self, idx: &Series) -> DaftResult<Series> {
use DataType::*;

let idx = idx.cast(&Int64)?;
let idx_arr = idx.i64().unwrap();

match self.data_type() {
List(_) => self.list()?.get_children(idx_arr),
FixedSizeList(..) => self.fixed_size_list()?.get_children(idx_arr),
dt => Err(DaftError::TypeError(format!(
"Get not implemented for {}",
dt
))),
}
}
}
47 changes: 47 additions & 0 deletions src/daft-dsl/src/functions/list/get.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
use crate::Expr;
use daft_core::{datatypes::Field, schema::Schema, series::Series};

use common_error::{DaftError, DaftResult};

use super::super::FunctionEvaluator;

pub(super) struct GetEvaluator {}

impl FunctionEvaluator for GetEvaluator {
fn fn_name(&self) -> &'static str {
"get"
}

fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult<Field> {
match inputs {
[input, idx] => {
let input_field = input.to_field(schema)?;
let idx_field = idx.to_field(schema)?;

if !idx_field.dtype.is_integer() {
return Err(DaftError::TypeError(format!(
"Expected get index to be integer, received: {}",
idx_field.dtype
)));
}

let exploded_field = input_field.to_exploded_field()?;
Ok(exploded_field)
}
_ => Err(DaftError::SchemaMismatch(format!(
"Expected 2 input args, got {}",
inputs.len()
))),
}
}

fn evaluate(&self, inputs: &[Series], _: &Expr) -> DaftResult<Series> {
match inputs {
[input, idx] => Ok(input.list_get(idx)?),
_ => Err(DaftError::ValueError(format!(
"Expected 2 input args, got {}",
inputs.len()
))),
}
}
}
11 changes: 11 additions & 0 deletions src/daft-dsl/src/functions/list/mod.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
mod explode;
mod get;
mod join;
mod lengths;

use explode::ExplodeEvaluator;
use get::GetEvaluator;
use join::JoinEvaluator;
use lengths::LengthsEvaluator;
use serde::{Deserialize, Serialize};
Expand All @@ -16,6 +18,7 @@ pub enum ListExpr {
Explode,
Join,
Lengths,
Get,
}

impl ListExpr {
Expand All @@ -26,6 +29,7 @@ impl ListExpr {
Explode => &ExplodeEvaluator {},
Join => &JoinEvaluator {},
Lengths => &LengthsEvaluator {},
Get => &GetEvaluator {},
}
}
}
Expand All @@ -50,3 +54,10 @@ pub fn lengths(input: &Expr) -> Expr {
inputs: vec![input.clone()],
}
}

pub fn get(input: &Expr, idx: &Expr) -> Expr {
Expr::Function {
func: super::FunctionExpr::List(ListExpr::Get),
inputs: vec![input.clone(), idx.clone()],
}
}
5 changes: 5 additions & 0 deletions src/daft-dsl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,11 @@ impl PyExpr {
Ok(lengths(&self.expr).into())
}

pub fn list_get(&self, idx: &Self) -> PyResult<Self> {
use crate::functions::list::get;
Ok(get(&self.expr, &idx.expr).into())
}

pub fn url_download(
&self,
max_connections: i64,
Expand Down

0 comments on commit a7b5e5e

Please sign in to comment.