Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] Add getter for Struct and List expressions #1775

Merged
merged 9 commits into from
Jan 12, 2024
Merged
3 changes: 3 additions & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -857,6 +857,8 @@ class PyExpr:
def image_crop(self, bbox: PyExpr) -> PyExpr: ...
def list_join(self, delimiter: PyExpr) -> PyExpr: ...
def list_lengths(self) -> PyExpr: ...
def list_get(self, idx: PyExpr, default: PyExpr) -> PyExpr: ...
def struct_get(self, name: str) -> PyExpr: ...
def url_download(
self, max_connections: int, raise_error_on_failure: bool, multi_thread: bool, config: IOConfig
) -> PyExpr: ...
Expand Down Expand Up @@ -928,6 +930,7 @@ class PySeries:
def partitioning_months(self) -> PySeries: ...
def partitioning_years(self) -> PySeries: ...
def list_lengths(self) -> PySeries: ...
def list_get(self, idx: PySeries, default: PySeries) -> PySeries: ...
def image_decode(self) -> PySeries: ...
def image_encode(self, image_format: ImageFormat) -> PySeries: ...
def image_resize(self, w: int, h: int) -> PySeries: ...
Expand Down
32 changes: 32 additions & 0 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@ def list(self) -> ExpressionListNamespace:
"""Access methods that work on columns of lists"""
return ExpressionListNamespace.from_expression(self)

@property
def struct(self) -> ExpressionStructNamespace:
"""Access methods that work on columns of structs"""
return ExpressionStructNamespace.from_expression(self)

@property
def image(self) -> ExpressionImageNamespace:
"""Access methods that work on columns of images"""
Expand Down Expand Up @@ -671,6 +676,33 @@ def lengths(self) -> Expression:
"""
return Expression._from_pyexpr(self._expr.list_lengths())

def get(self, idx: int | Expression, default: object = None) -> Expression:
"""Gets the element at an index in each list

Args:
idx: index or indices to retrieve from each list
default: the default value if the specified index is out of bounds

Returns:
Expression: an expression with the type of the list values
"""
idx_expr = Expression._to_expression(idx)
default_expr = lit(default)
return Expression._from_pyexpr(self._expr.list_get(idx_expr._expr, default_expr._expr))


class ExpressionStructNamespace(ExpressionNamespace):
def get(self, name: str) -> Expression:
"""Retrieves one field from a struct column

Args:
name: the name of the field to retrieve

Returns:
Expression: the field expression
"""
return Expression._from_pyexpr(self._expr.struct_get(name))


class ExpressionsProjection(Iterable[Expression]):
"""A collection of Expressions that can be projected onto a Table to produce another Table
Expand Down
3 changes: 3 additions & 0 deletions daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,9 @@ class SeriesListNamespace(SeriesNamespace):
def lengths(self) -> Series:
return Series._from_pyseries(self._series.list_lengths())

def get(self, idx: Series, default: Series) -> Series:
return Series._from_pyseries(self._series.list_get(idx._series, default._series))


class SeriesImageNamespace(SeriesNamespace):
def decode(self) -> Series:
Expand Down
16 changes: 16 additions & 0 deletions src/daft-core/src/array/fixed_size_list_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,22 @@ impl FixedSizeListArray {
_ => unreachable!("FixedSizeListArray should always have FixedSizeList datatype"),
}
}

pub fn with_validity(&self, validity: Option<arrow2::bitmap::Bitmap>) -> DaftResult<Self> {
if let Some(v) = &validity && v.len() != self.len() {
return Err(DaftError::ValueError(format!(
"validity mask length does not match FixedSizeListArray length, {} vs {}",
v.len(),
self.len()
)))
}

Ok(Self::new(
self.field.clone(),
self.flat_child.clone(),
validity,
))
}
}

#[cfg(test)]
Expand Down
17 changes: 17 additions & 0 deletions src/daft-core/src/array/list_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,4 +160,21 @@ impl ListArray {
self.validity.clone(),
))
}

pub fn with_validity(&self, validity: Option<arrow2::bitmap::Bitmap>) -> DaftResult<Self> {
if let Some(v) = &validity && v.len() != self.len() {
return Err(DaftError::ValueError(format!(
"validity mask length does not match ListArray length, {} vs {}",
v.len(),
self.len()
)))
}

Ok(Self::new(
self.field.clone(),
self.flat_child.clone(),
self.offsets.clone(),
validity,
))
}
}
20 changes: 18 additions & 2 deletions src/daft-core/src/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ pub mod ops;
pub mod pseudo_arrow;
mod serdes;
mod struct_array;
use arrow2::bitmap::Bitmap;
pub use fixed_size_list_array::FixedSizeListArray;
pub use list_array::ListArray;

Expand Down Expand Up @@ -75,19 +76,34 @@ where
self.len() == 0
}

pub fn with_validity(&self, validity: &[bool]) -> DaftResult<Self> {
pub fn with_validity_slice(&self, validity: &[bool]) -> DaftResult<Self> {
if validity.len() != self.data.len() {
return Err(DaftError::ValueError(format!(
"validity mask length does not match DataArray length, {} vs {}",
validity.len(),
self.data.len()
)));
}
use arrow2::bitmap::Bitmap;
let with_bitmap = self.data.with_validity(Some(Bitmap::from(validity)));
DataArray::new(self.field.clone(), with_bitmap)
}

pub fn with_validity(&self, validity: Option<Bitmap>) -> DaftResult<Self> {
if let Some(v) = &validity && v.len() != self.data.len() {
return Err(DaftError::ValueError(format!(
"validity mask length does not match DataArray length, {} vs {}",
v.len(),
self.data.len()
)));
}
let with_bitmap = self.data.with_validity(validity);
DataArray::new(self.field.clone(), with_bitmap)
}

pub fn validity(&self) -> Option<&Bitmap> {
self.data.validity()
}

pub fn slice(&self, start: usize, end: usize) -> DaftResult<Self> {
if start > end {
return Err(DaftError::ValueError(format!(
Expand Down
32 changes: 16 additions & 16 deletions src/daft-core/src/array/ops/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1474,7 +1474,7 @@ mod tests {
let result: Vec<_> = array.equal(2).into_iter().collect();
assert_eq!(result[..], [Some(false), Some(true), Some(false)]);

let array = array.with_validity(&[true, false, true])?;
let array = array.with_validity_slice(&[true, false, true])?;
let result: Vec<_> = array.equal(2).into_iter().collect();
assert_eq!(result[..], [Some(false), None, Some(false)]);
Ok(())
Expand All @@ -1487,7 +1487,7 @@ mod tests {
let result: Vec<_> = array.not_equal(2).into_iter().collect();
assert_eq!(result[..], [Some(true), Some(false), Some(true)]);

let array = array.with_validity(&[true, false, true])?;
let array = array.with_validity_slice(&[true, false, true])?;
let result: Vec<_> = array.not_equal(2).into_iter().collect();
assert_eq!(result[..], [Some(true), None, Some(true)]);
Ok(())
Expand All @@ -1500,7 +1500,7 @@ mod tests {
let result: Vec<_> = array.lt(2).into_iter().collect();
assert_eq!(result[..], [Some(true), Some(false), Some(false)]);

let array = array.with_validity(&[true, false, true])?;
let array = array.with_validity_slice(&[true, false, true])?;
let result: Vec<_> = array.lt(2).into_iter().collect();
assert_eq!(result[..], [Some(true), None, Some(false)]);
Ok(())
Expand All @@ -1513,7 +1513,7 @@ mod tests {
let result: Vec<_> = array.lte(2).into_iter().collect();
assert_eq!(result[..], [Some(true), Some(true), Some(false)]);

let array = array.with_validity(&[true, false, true])?;
let array = array.with_validity_slice(&[true, false, true])?;
let result: Vec<_> = array.lte(2).into_iter().collect();
assert_eq!(result[..], [Some(true), None, Some(false)]);
Ok(())
Expand All @@ -1526,7 +1526,7 @@ mod tests {
let result: Vec<_> = array.gt(2).into_iter().collect();
assert_eq!(result[..], [Some(false), Some(false), Some(true)]);

let array = array.with_validity(&[true, false, true])?;
let array = array.with_validity_slice(&[true, false, true])?;
let result: Vec<_> = array.gt(2).into_iter().collect();
assert_eq!(result[..], [Some(false), None, Some(true)]);
Ok(())
Expand All @@ -1539,7 +1539,7 @@ mod tests {
let result: Vec<_> = array.gte(2).into_iter().collect();
assert_eq!(result[..], [Some(false), Some(true), Some(true)]);

let array = array.with_validity(&[true, false, true])?;
let array = array.with_validity_slice(&[true, false, true])?;
let result: Vec<_> = array.gte(2).into_iter().collect();
assert_eq!(result[..], [Some(false), None, Some(true)]);
Ok(())
Expand All @@ -1552,7 +1552,7 @@ mod tests {
let result: Vec<_> = array.equal(&array)?.into_iter().collect();
assert_eq!(result[..], [Some(true), Some(true), Some(true)]);

let array = array.with_validity(&[true, false, true])?;
let array = array.with_validity_slice(&[true, false, true])?;
let result: Vec<_> = array.equal(&array)?.into_iter().collect();
assert_eq!(result[..], [Some(true), None, Some(true)]);
Ok(())
Expand All @@ -1565,7 +1565,7 @@ mod tests {
let result: Vec<_> = array.not_equal(&array)?.into_iter().collect();
assert_eq!(result[..], [Some(false), Some(false), Some(false)]);

let array = array.with_validity(&[true, false, true])?;
let array = array.with_validity_slice(&[true, false, true])?;
let result: Vec<_> = array.not_equal(&array)?.into_iter().collect();
assert_eq!(result[..], [Some(false), None, Some(false)]);
Ok(())
Expand All @@ -1578,11 +1578,11 @@ mod tests {
let result: Vec<_> = lhs.lt(&rhs)?.into_iter().collect();
assert_eq!(result[..], [Some(false), Some(false), Some(true)]);

let lhs = lhs.with_validity(&[true, false, true])?;
let lhs = lhs.with_validity_slice(&[true, false, true])?;
let result: Vec<_> = lhs.lt(&rhs)?.into_iter().collect();
assert_eq!(result[..], [Some(false), None, Some(true)]);

let rhs = rhs.with_validity(&[false, true, true])?;
let rhs = rhs.with_validity_slice(&[false, true, true])?;
let result: Vec<_> = lhs.lt(&rhs)?.into_iter().collect();
assert_eq!(result[..], [None, None, Some(true)]);
Ok(())
Expand All @@ -1595,11 +1595,11 @@ mod tests {
let result: Vec<_> = lhs.lte(&rhs)?.into_iter().collect();
assert_eq!(result[..], [Some(false), Some(true), Some(true)]);

let lhs = lhs.with_validity(&[true, false, true])?;
let lhs = lhs.with_validity_slice(&[true, false, true])?;
let result: Vec<_> = lhs.lte(&rhs)?.into_iter().collect();
assert_eq!(result[..], [Some(false), None, Some(true)]);

let rhs = rhs.with_validity(&[false, true, true])?;
let rhs = rhs.with_validity_slice(&[false, true, true])?;
let result: Vec<_> = lhs.lte(&rhs)?.into_iter().collect();
assert_eq!(result[..], [None, None, Some(true)]);
Ok(())
Expand All @@ -1612,11 +1612,11 @@ mod tests {
let result: Vec<_> = lhs.gt(&rhs)?.into_iter().collect();
assert_eq!(result[..], [Some(true), Some(false), Some(false)]);

let lhs = lhs.with_validity(&[true, false, true])?;
let lhs = lhs.with_validity_slice(&[true, false, true])?;
let result: Vec<_> = lhs.gt(&rhs)?.into_iter().collect();
assert_eq!(result[..], [Some(true), None, Some(false)]);

let rhs = rhs.with_validity(&[false, true, true])?;
let rhs = rhs.with_validity_slice(&[false, true, true])?;
let result: Vec<_> = lhs.gt(&rhs)?.into_iter().collect();
assert_eq!(result[..], [None, None, Some(false)]);
Ok(())
Expand All @@ -1629,11 +1629,11 @@ mod tests {
let result: Vec<_> = lhs.gte(&rhs)?.into_iter().collect();
assert_eq!(result[..], [Some(true), Some(true), Some(false)]);

let lhs = lhs.with_validity(&[true, false, true])?;
let lhs = lhs.with_validity_slice(&[true, false, true])?;
let result: Vec<_> = lhs.gte(&rhs)?.into_iter().collect();
assert_eq!(result[..], [Some(true), None, Some(false)]);

let rhs = rhs.with_validity(&[false, true, true])?;
let rhs = rhs.with_validity_slice(&[false, true, true])?;
let result: Vec<_> = lhs.gte(&rhs)?.into_iter().collect();
assert_eq!(result[..], [None, None, Some(false)]);
Ok(())
Expand Down
Loading
Loading