Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] Add getter for Struct and List expressions #1775

Merged
merged 9 commits into from
Jan 12, 2024
Merged
2 changes: 2 additions & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -857,6 +857,7 @@ class PyExpr:
def image_crop(self, bbox: PyExpr) -> PyExpr: ...
def list_join(self, delimiter: PyExpr) -> PyExpr: ...
def list_lengths(self) -> PyExpr: ...
# Index-based element lookup on a list expression; the index is itself passed as a PyExpr.
def list_get(self, idx: PyExpr) -> PyExpr: ...
def url_download(
self, max_connections: int, raise_error_on_failure: bool, multi_thread: bool, config: IOConfig
) -> PyExpr: ...
Expand Down Expand Up @@ -928,6 +929,7 @@ class PySeries:
def partitioning_months(self) -> PySeries: ...
def partitioning_years(self) -> PySeries: ...
def list_lengths(self) -> PySeries: ...
# Index-based element lookup on a list series; the indices are passed as another PySeries.
def list_get(self, idx: PySeries) -> PySeries: ...
def image_decode(self) -> PySeries: ...
def image_encode(self, image_format: ImageFormat) -> PySeries: ...
def image_resize(self, w: int, h: int) -> PySeries: ...
Expand Down
13 changes: 13 additions & 0 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,19 @@
"""
return Expression._from_pyexpr(self._expr.list_lengths())

def get(self, idx: int | Expression, default: object = None) -> Expression:
    """Retrieves one element from every list, selected by position.

    Args:
        idx: a single index, or an expression producing the index to look up in each list
        default: value meant for lookups that fall outside a list.
            NOTE(review): this argument is not forwarded to ``list_get`` below,
            so it currently has no effect on the result.

    Returns:
        Expression: an expression with the type of the list values
    """
    return Expression._from_pyexpr(self._expr.list_get(Expression._to_expression(idx)._expr))

Check warning on line 685 in daft/expressions/expressions.py

View check run for this annotation

Codecov / codecov/patch

daft/expressions/expressions.py#L684-L685

Added lines #L684 - L685 were not covered by tests


class ExpressionsProjection(Iterable[Expression]):
"""A collection of Expressions that can be projected onto a Table to produce another Table
Expand Down
3 changes: 3 additions & 0 deletions daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,9 @@
def lengths(self) -> Series:
    """Wraps the result of the underlying ``list_lengths`` call in a new Series."""
    raw_lengths = self._series.list_lengths()
    return Series._from_pyseries(raw_lengths)

def get(self, idx: Series, default: Series) -> Series:
    """Looks up one element per list using the indices held in ``idx``.

    Args:
        idx: series whose underlying values are handed to ``list_get``
        default: NOTE(review): accepted but not used by this implementation.

    Returns:
        Series: the result of ``list_get`` wrapped in a new Series
    """
    picked = self._series.list_get(idx._series)
    return Series._from_pyseries(picked)

Check warning on line 604 in daft/series.py

View check run for this annotation

Codecov / codecov/patch

daft/series.py#L604

Added line #L604 was not covered by tests


class SeriesImageNamespace(SeriesNamespace):
def decode(self) -> Series:
Expand Down
96 changes: 95 additions & 1 deletion src/daft-core/src/array/ops/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::array::{
growable::{make_growable, Growable},
FixedSizeListArray, ListArray,
};
use crate::datatypes::{UInt64Array, Utf8Array};
use crate::datatypes::{Int64Array, UInt64Array, Utf8Array};
use crate::DataType;

use crate::series::Series;
Expand Down Expand Up @@ -112,6 +112,55 @@ impl ListArray {
Box::new(arrow2::array::Utf8Array::from_iter(result)),
)))
}

/// Picks one element out of every list according to a per-row index.
///
/// `idx_iter` must yield exactly one entry per row of `self`. A non-negative
/// index counts from the start of the list, a negative one from its end.
/// The output row is null when the list is null, when the index is null, or
/// when the index falls outside the list.
///
/// # Panics
///
/// Panics if `idx_iter` yields fewer entries than `self.len()`.
fn get_children_helper(
    &self,
    idx_iter: &mut impl Iterator<Item = Option<i64>>,
) -> DaftResult<Series> {
    let mut growable = make_growable(
        self.name(),
        self.child_data_type(),
        vec![&self.flat_child],
        true,
        self.len(),
    );

    let offsets = self.offsets();

    for i in 0..self.len() {
        let start = *offsets.get(i).unwrap();
        let end = *offsets.get(i + 1).unwrap();
        let child_idx = idx_iter
            .next()
            .expect("index iterator must yield one entry per list");

        // Resolve the index to a position in the flat child array. A null index,
        // a null list, an arithmetic overflow, or a position outside
        // `start..end` all resolve to `None` and become a null output row.
        let position = match child_idx {
            Some(idx) if self.is_valid(i) => {
                // Negative indices are anchored at the end of the list.
                let base = if idx >= 0 { start } else { end };
                base.checked_add(idx)
                    .filter(|pos| (start..end).contains(pos))
            }
            _ => None,
        };

        match position {
            Some(pos) => growable.extend(0, pos as usize, 1),
            None => growable.add_nulls(1),
        };
    }

    growable.build()
}

/// Fetches one element per list, delegating the lookup to
/// `get_children_helper`.
///
/// When `idx` holds exactly one value, that value is repeated for every row;
/// otherwise `idx` supplies one index per row.
///
/// # Panics
///
/// Panics if `idx.len()` is neither 1 nor `self.len()`.
pub fn get_children(&self, idx: &Int64Array) -> DaftResult<Series> {
    if idx.len() == 1 {
        // Broadcast the single index across all rows.
        let mut broadcast = repeat(idx.get(0)).take(self.len());
        return self.get_children_helper(&mut broadcast);
    }

    assert_eq!(idx.len(), self.len());
    let mut per_row = idx.as_arrow().iter().map(|value| value.copied());
    self.get_children_helper(&mut per_row)
}
}

impl FixedSizeListArray {
Expand Down Expand Up @@ -190,4 +239,49 @@ impl FixedSizeListArray {
Box::new(arrow2::array::Utf8Array::from_iter(result)),
)))
}

/// Picks one element out of every fixed-size list according to a per-row index.
///
/// `idx_iter` must yield exactly one entry per row of `self`. A non-negative
/// index counts from the start of the list, a negative one from its end, so
/// valid indices span `-list_size..list_size`. The output row is null when the
/// list is null, when the index is null, or when the index is out of range.
///
/// # Panics
///
/// Panics if `idx_iter` yields fewer entries than `self.len()`.
fn get_children_helper(
    &self,
    idx_iter: &mut impl Iterator<Item = Option<i64>>,
) -> DaftResult<Series> {
    let mut growable = make_growable(
        self.name(),
        self.child_data_type(),
        vec![&self.flat_child],
        true,
        self.len(),
    );

    let list_size = self.fixed_element_len();
    let signed_size = list_size as i64;

    for i in 0..self.len() {
        let child_idx = idx_iter
            .next()
            .expect("index iterator must yield one entry per list");

        // Resolve the index to an offset inside the i-th list. A null index,
        // a null list, or an out-of-range offset becomes a null output row.
        let offset = match child_idx {
            Some(idx) if self.is_valid(i) => {
                // Negative indices count back from the end. The sum cannot
                // overflow because the two operands have opposite signs.
                let resolved = if idx < 0 { idx + signed_size } else { idx };
                (0..signed_size)
                    .contains(&resolved)
                    .then_some(resolved as usize)
            }
            _ => None,
        };

        match offset {
            Some(pos) => growable.extend(0, i * list_size + pos, 1),
            None => growable.add_nulls(1),
        };
    }

    growable.build()
}

/// Fetches one element per fixed-size list, delegating the lookup to
/// `get_children_helper`.
///
/// A one-element `idx` is repeated for every row; any other `idx` is read as
/// one index per row.
///
/// # Panics
///
/// Panics if `idx.len()` is neither 1 nor `self.len()`.
pub fn get_children(&self, idx: &Int64Array) -> DaftResult<Series> {
    let row_count = self.len();

    if idx.len() == 1 {
        // A lone index applies to every list.
        let mut same_for_all = repeat(idx.get(0)).take(row_count);
        return self.get_children_helper(&mut same_for_all);
    }

    assert_eq!(idx.len(), self.len());
    let mut row_indices = idx.as_arrow().iter().map(|entry| entry.copied());
    self.get_children_helper(&mut row_indices)
}
}
4 changes: 4 additions & 0 deletions src/daft-core/src/python/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,10 @@ impl PySeries {
Ok(self.series.list_lengths()?.into_series().into())
}

/// Python-facing wrapper: runs `list_get` on the wrapped series with the
/// series wrapped by `idx`, then converts the result back into `Self`.
pub fn list_get(&self, idx: &Self) -> PyResult<Self> {
    let selected = self.series.list_get(&idx.series)?;
    Ok(selected.into())
}

/// Python-facing wrapper: runs `image_decode` on the wrapped series and
/// converts the result back into `Self`.
pub fn image_decode(&self) -> PyResult<Self> {
    let decoded = self.series.image_decode()?;
    Ok(decoded.into())
}
Expand Down
16 changes: 16 additions & 0 deletions src/daft-core/src/series/ops/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,20 @@ impl Series {
))),
}
}

pub fn list_get(&self, idx: &Series) -> DaftResult<Series> {
use DataType::*;

let idx = idx.cast(&Int64)?;
kevinzwang marked this conversation as resolved.
Show resolved Hide resolved
let idx_arr = idx.i64().unwrap();

match self.data_type() {
List(_) => self.list()?.get_children(idx_arr),
FixedSizeList(..) => self.fixed_size_list()?.get_children(idx_arr),
dt => Err(DaftError::TypeError(format!(
"Get not implemented for {}",
dt
))),
}
}
}
47 changes: 47 additions & 0 deletions src/daft-dsl/src/functions/list/get.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
use crate::Expr;
use daft_core::{datatypes::Field, schema::Schema, series::Series};

use common_error::{DaftError, DaftResult};

use super::super::FunctionEvaluator;

/// Function evaluator for the list `get` operation: takes a list input and an
/// integer index input.
pub(super) struct GetEvaluator {}

impl FunctionEvaluator for GetEvaluator {
    /// Name reported for this function.
    fn fn_name(&self) -> &'static str {
        "get"
    }

    /// Resolves the output field: the exploded field of the list input.
    ///
    /// # Errors
    ///
    /// `SchemaMismatch` unless exactly two inputs are given, and `TypeError`
    /// when the index input is not an integer type. Errors from resolving
    /// either input field or from exploding the list field are propagated.
    fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult<Field> {
        if inputs.len() != 2 {
            return Err(DaftError::SchemaMismatch(format!(
                "Expected 2 input args, got {}",
                inputs.len()
            )));
        }
        let (list_expr, index_expr) = (&inputs[0], &inputs[1]);

        // Both fields are resolved before the index type is checked.
        let list_field = list_expr.to_field(schema)?;
        let index_field = index_expr.to_field(schema)?;

        if index_field.dtype.is_integer() {
            list_field.to_exploded_field()
        } else {
            Err(DaftError::TypeError(format!(
                "Expected get index to be integer, received: {}",
                index_field.dtype
            )))
        }
    }

    /// Runs `list_get` on the list input with the index input.
    ///
    /// # Errors
    ///
    /// `ValueError` unless exactly two inputs are given; errors from
    /// `list_get` are propagated.
    fn evaluate(&self, inputs: &[Series], _: &Expr) -> DaftResult<Series> {
        if let [list, index] = inputs {
            list.list_get(index)
        } else {
            Err(DaftError::ValueError(format!(
                "Expected 2 input args, got {}",
                inputs.len()
            )))
        }
    }
}
11 changes: 11 additions & 0 deletions src/daft-dsl/src/functions/list/mod.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
mod explode;
mod get;
mod join;
mod lengths;

use explode::ExplodeEvaluator;
use get::GetEvaluator;
use join::JoinEvaluator;
use lengths::LengthsEvaluator;
use serde::{Deserialize, Serialize};
Expand All @@ -16,6 +18,7 @@ pub enum ListExpr {
Explode,
Join,
Lengths,
Get,
}

impl ListExpr {
Expand All @@ -26,6 +29,7 @@ impl ListExpr {
Explode => &ExplodeEvaluator {},
Join => &JoinEvaluator {},
Lengths => &LengthsEvaluator {},
Get => &GetEvaluator {},
}
}
}
Expand All @@ -50,3 +54,10 @@ pub fn lengths(input: &Expr) -> Expr {
inputs: vec![input.clone()],
}
}

/// Builds a function expression applying `ListExpr::Get` to clones of `input`
/// and `idx`, in that order.
pub fn get(input: &Expr, idx: &Expr) -> Expr {
    let func = super::FunctionExpr::List(ListExpr::Get);
    let inputs = vec![input.clone(), idx.clone()];
    Expr::Function { func, inputs }
}
5 changes: 5 additions & 0 deletions src/daft-dsl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,11 @@ impl PyExpr {
Ok(lengths(&self.expr).into())
}

/// Python-facing wrapper: builds the list `get` expression from this
/// expression and `idx`, then converts it back into `Self`.
pub fn list_get(&self, idx: &Self) -> PyResult<Self> {
    let expr = crate::functions::list::get(&self.expr, &idx.expr);
    Ok(expr.into())
}

pub fn url_download(
&self,
max_connections: i64,
Expand Down
Loading