Skip to content

Commit

Permalink
[FEAT] support new_zero/zero_lit in Rust/Python side
Browse files Browse the repository at this point in the history
This commit also fixes a bug when converting struct literal to series
  • Loading branch information
advancedxy committed Oct 28, 2024
1 parent 5228930 commit 8b16405
Show file tree
Hide file tree
Showing 8 changed files with 291 additions and 31 deletions.
3 changes: 2 additions & 1 deletion daft/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def refresh_logger() -> None:
from daft.dataframe import DataFrame
from daft.logical.schema import Schema
from daft.datatype import DataType, TimeUnit
from daft.expressions import Expression, col, lit, interval
from daft.expressions import Expression, col, lit, interval, zero_lit
from daft.io import (
DataCatalogTable,
DataCatalogType,
Expand Down Expand Up @@ -120,6 +120,7 @@ def refresh_logger() -> None:
"ImageMode",
"ImageFormat",
"lit",
"zero_lit",
"Series",
"TimeUnit",
"register_viz_hook",
Expand Down
1 change: 1 addition & 0 deletions daft/daft/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1135,6 +1135,7 @@ class PyExpr:
def eq(expr1: PyExpr, expr2: PyExpr) -> bool: ...
def col(name: str) -> PyExpr: ...
def lit(item: Any) -> PyExpr: ...
def zero_value(dt: PyDataType) -> PyExpr: ...
def date_lit(item: int) -> PyExpr: ...
def time_lit(item: int, tu: PyTimeUnit) -> PyExpr: ...
def timestamp_lit(item: int, tu: PyTimeUnit, tz: str | None) -> PyExpr: ...
Expand Down
4 changes: 2 additions & 2 deletions daft/expressions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import annotations

from .expressions import Expression, ExpressionsProjection, col, lit, interval
from .expressions import Expression, ExpressionsProjection, col, lit, interval, zero_lit

__all__ = ["Expression", "ExpressionsProjection", "col", "lit", "interval"]
__all__ = ["Expression", "ExpressionsProjection", "col", "lit", "interval", "zero_lit"]
34 changes: 34 additions & 0 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from daft.daft import tokenize_encode as _tokenize_encode
from daft.daft import url_download as _url_download
from daft.daft import utf8_count_matches as _utf8_count_matches
from daft.daft import zero_value as _zero_value
from daft.datatype import DataType, TimeUnit
from daft.dependencies import pa
from daft.expressions.testing import expr_structurally_equal
Expand Down Expand Up @@ -133,6 +134,39 @@ def lit(value: object) -> Expression:
return Expression._from_pyexpr(lit_value)


def zero_lit(dt: DataType) -> Expression:
"""Creates a literal Expression representing a zero value of corresponding data type
Example:
>>> import daft
>>> from daft import DataType
>>> df = daft.from_pydict({"x": [1, 2, 3]})
>>> df = df.with_column("y", daft.zero_lit(DataType.int32()))
>>> df.show()
╭───────┬───────╮
│ x ┆ y │
│ --- ┆ --- │
│ Int64 ┆ Int32 │
╞═══════╪═══════╡
│ 1 ┆ 0 │
├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ 2 ┆ 0 │
├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ 3 ┆ 0 │
╰───────┴───────╯
<BLANKLINE>
(Showing first 3 of 3 rows)
Args:
dt: data type of the zero value
Returns:
Expression: representing the zero value of the data type
"""
zero = _zero_value(dt._dtype)
return Expression._from_pyexpr(zero)


def col(name: str) -> Expression:
"""Creates an Expression referring to the column with the provided name.
Expand Down
1 change: 1 addition & 0 deletions src/daft-dsl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ pub fn register_modules(parent: &Bound<PyModule>) -> PyResult<()> {
parent.add_function(wrap_pyfunction_bound!(python::interval_lit, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(python::decimal_lit, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(python::series_lit, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(python::zero_value, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(python::stateless_udf, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(python::stateful_udf, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(
Expand Down
238 changes: 211 additions & 27 deletions src/daft-dsl/src/lit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,67 @@ impl Display for LiteralValue {
}

impl LiteralValue {
pub fn new_zero(dt: &DataType) -> DaftResult<Self> {
Ok(match dt {
DataType::Null => Self::Null,
DataType::Boolean => Self::Boolean(false),
DataType::Utf8 => Self::Utf8(String::new()),
DataType::Binary => Self::Binary(vec![]),
DataType::FixedSizeBinary(usize) => Self::Binary(vec![0; *usize]),
DataType::Int32 => Self::Int32(0),
DataType::UInt32 => Self::UInt32(0),
DataType::Int64 => Self::Int64(0),
DataType::UInt64 => Self::UInt64(0),
DataType::Date => Self::Date(0),
DataType::Time(unit) => Self::Time(0, *unit),
DataType::Timestamp(unit, time_zone) => Self::Timestamp(0, *unit, time_zone.clone()),
DataType::Duration(unit) => Self::Duration(0, *unit),
DataType::Float64 => Self::Float64(0.0),
DataType::Decimal128(precision, scale) => {
Self::Decimal(0, *precision as u8, *scale as i8)
}
DataType::Interval => Self::Interval(IntervalValue::new(0, 0, 0)),
DataType::List(item) => Self::Series(Series::empty("literal", item)),
DataType::FixedSizeList(item, usize) => {
// a list of nulls or zero values?
Self::Series(Series::full_null("literal", item, *usize))
}
// No support for map type yet
// DataType::Map { .. } => {},
#[cfg(feature = "python")]
DataType::Python => {
use pyo3::prelude::*;
Self::Python(PyObjectWrapper(Python::with_gil(|py| py.None())))
}
DataType::Struct(s) => {
let record = s
.iter()
.map(|field| {
let zero = Self::new_zero(&field.dtype);
zero.map(|v| (field.clone(), v))
})
.collect::<DaftResult<IndexMap<_, _>>>()?;
Self::Struct(record)
}
DataType::Int8
| DataType::UInt8
| DataType::Int16
| DataType::UInt16
| DataType::Float32 => {
return Err(DaftError::TypeError(format!(
"Unsupported numeric type: {:?}",
dt
)))
}
_ => {
return Err(DaftError::TypeError(format!(
"Unsupported data type: {:?}",
dt
)))
}
})
}

pub fn get_type(&self) -> DataType {
match self {
Self::Null => DataType::Null,
Expand All @@ -204,57 +265,64 @@ impl LiteralValue {
}
}

pub fn to_series(&self) -> Series {
fn to_series_helper(&self, field_name: Option<&str>) -> Series {
let field_name = field_name.unwrap_or("literal");
match self {
Self::Null => NullArray::full_null("literal", &DataType::Null, 1).into_series(),
Self::Boolean(val) => BooleanArray::from(("literal", [*val].as_slice())).into_series(),
Self::Null => NullArray::full_null(field_name, &DataType::Null, 1).into_series(),
Self::Boolean(val) => BooleanArray::from((field_name, [*val].as_slice())).into_series(),
Self::Utf8(val) => {
Utf8Array::from(("literal", [val.as_str()].as_slice())).into_series()
Utf8Array::from((field_name, [val.as_str()].as_slice())).into_series()
}
Self::Binary(val) => BinaryArray::from(("literal", val.as_slice())).into_series(),
Self::Int32(val) => Int32Array::from(("literal", [*val].as_slice())).into_series(),
Self::UInt32(val) => UInt32Array::from(("literal", [*val].as_slice())).into_series(),
Self::Int64(val) => Int64Array::from(("literal", [*val].as_slice())).into_series(),
Self::UInt64(val) => UInt64Array::from(("literal", [*val].as_slice())).into_series(),
Self::Binary(val) => BinaryArray::from((field_name, val.as_slice())).into_series(),
Self::Int32(val) => Int32Array::from((field_name, [*val].as_slice())).into_series(),
Self::UInt32(val) => UInt32Array::from((field_name, [*val].as_slice())).into_series(),
Self::Int64(val) => Int64Array::from((field_name, [*val].as_slice())).into_series(),
Self::UInt64(val) => UInt64Array::from((field_name, [*val].as_slice())).into_series(),
Self::Date(val) => {
let physical = Int32Array::from(("literal", [*val].as_slice()));
DateArray::new(Field::new("literal", self.get_type()), physical).into_series()
let physical = Int32Array::from((field_name, [*val].as_slice()));
DateArray::new(Field::new(field_name, self.get_type()), physical).into_series()
}
Self::Time(val, ..) => {
let physical = Int64Array::from(("literal", [*val].as_slice()));
TimeArray::new(Field::new("literal", self.get_type()), physical).into_series()
let physical = Int64Array::from((field_name, [*val].as_slice()));
TimeArray::new(Field::new(field_name, self.get_type()), physical).into_series()
}
Self::Timestamp(val, ..) => {
let physical = Int64Array::from(("literal", [*val].as_slice()));
TimestampArray::new(Field::new("literal", self.get_type()), physical).into_series()
let physical = Int64Array::from((field_name, [*val].as_slice()));
TimestampArray::new(Field::new(field_name, self.get_type()), physical).into_series()
}
Self::Duration(val, ..) => {
let physical = Int64Array::from(("literal", [*val].as_slice()));
DurationArray::new(Field::new("literal", self.get_type()), physical).into_series()
let physical = Int64Array::from((field_name, [*val].as_slice()));
DurationArray::new(Field::new(field_name, self.get_type()), physical).into_series()
}
Self::Interval(val) => IntervalArray::from_values(
"literal",
field_name,
std::iter::once((val.months, val.days, val.nanoseconds)),
)
.into_series(),
Self::Float64(val) => Float64Array::from(("literal", [*val].as_slice())).into_series(),
Self::Float64(val) => Float64Array::from((field_name, [*val].as_slice())).into_series(),
Self::Decimal(val, ..) => {
let physical = Int128Array::from(("literal", [*val].as_slice()));
Decimal128Array::new(Field::new("literal", self.get_type()), physical).into_series()
let physical = Int128Array::from((field_name, [*val].as_slice()));
Decimal128Array::new(Field::new(field_name, self.get_type()), physical)
.into_series()
}
Self::Series(series) => series.clone().rename("literal"),
Self::Series(series) => series.clone().rename(field_name),
#[cfg(feature = "python")]
Self::Python(val) => PythonArray::from(("literal", vec![val.0.clone()])).into_series(),
Self::Python(val) => PythonArray::from((field_name, vec![val.0.clone()])).into_series(),
Self::Struct(entries) => {
let struct_dtype = DataType::Struct(entries.keys().cloned().collect());
let struct_field = Field::new("literal", struct_dtype);
let struct_field = Field::new(field_name, struct_dtype);

let values = entries.values().map(|v| v.to_series()).collect();
let values = entries
.iter()
.map(|(field, value)| value.to_series_helper(Some(&field.name)))
.collect();
StructArray::new(struct_field, values, None).into_series()
}
}
}

pub fn to_series(&self) -> Series {
self.to_series_helper(None)
}
pub fn display_sql<W: Write>(&self, buffer: &mut W) -> io::Result<()> {
let display_sql_err = Err(io::Error::new(
io::ErrorKind::Other,
Expand Down Expand Up @@ -554,9 +622,12 @@ pub fn literals_to_series(values: &[LiteralValue]) -> DaftResult<Series> {

#[cfg(test)]
mod test {
use daft_core::prelude::*;
use common_error::DaftError;
use daft_core::{datatypes::IntervalValue, prelude::*};

use super::LiteralValue;
#[cfg(feature = "python")]
use crate::pyobj_serde::PyObjectWrapper;

#[test]
fn test_literals_to_series() {
Expand Down Expand Up @@ -598,4 +669,117 @@ mod test {
let actual = super::literals_to_series(&values);
assert!(actual.is_err());
}

#[test]
fn test_struct_literal_to_serials() {
let values = vec![LiteralValue::Int32(1), LiteralValue::Int64(2)];
let fields = vec![
Field::new("a", DataType::Int32),
Field::new("b", DataType::Int64),
];
let struct_literal =
LiteralValue::Struct(fields.into_iter().zip(values.into_iter()).collect());
let series = struct_literal.to_series();
assert_eq!(series.len(), 1);
assert_eq!(
series.data_type(),
&DataType::Struct(vec![
Field::new("a", DataType::Int32),
Field::new("b", DataType::Int64),
])
);
}

#[test]
fn test_zero_literal_value() {
let type_and_expected_values = vec![
(DataType::Null, LiteralValue::Null),
(DataType::Boolean, LiteralValue::Boolean(false)),
(DataType::Utf8, LiteralValue::Utf8("".to_string())),
(DataType::Binary, LiteralValue::Binary(vec![])),
(DataType::FixedSizeBinary(1), LiteralValue::Binary(vec![0])),
(DataType::Int32, LiteralValue::Int32(0)),
(DataType::UInt32, LiteralValue::UInt32(0)),
(DataType::Int64, LiteralValue::Int64(0)),
(DataType::UInt64, LiteralValue::UInt64(0)),
(DataType::Date, LiteralValue::Date(0)),
(
DataType::Time(TimeUnit::Microseconds),
LiteralValue::Time(0, TimeUnit::Microseconds),
),
(
DataType::Timestamp(TimeUnit::Microseconds, Some("UTC".to_string())),
LiteralValue::Timestamp(0, TimeUnit::Microseconds, Some("UTC".to_string())),
),
(
DataType::Duration(TimeUnit::Microseconds),
LiteralValue::Duration(0, TimeUnit::Microseconds),
),
(DataType::Float64, LiteralValue::Float64(0.0)),
(DataType::Decimal128(1, 1), LiteralValue::Decimal(0, 1, 1)),
(
DataType::Interval,
LiteralValue::Interval(IntervalValue::new(0, 0, 0)),
),
(
DataType::List(Box::new(DataType::Int32)),
LiteralValue::Series(Series::empty("literal", &DataType::Int32)),
),
#[cfg(feature = "python")]
(DataType::Python, {
use pyo3::prelude::*;
LiteralValue::Python(PyObjectWrapper(Python::with_gil(|py| py.None())))
}),
(
DataType::Struct(vec![
Field::new("a", DataType::Int32),
Field::new("b", DataType::Int64),
]),
LiteralValue::Struct(
vec![
(Field::new("a", DataType::Int32), LiteralValue::Int32(0)),
(Field::new("b", DataType::Int64), LiteralValue::Int64(0)),
]
.into_iter()
.collect(),
),
),
];
for (dt, expected) in type_and_expected_values {
let actual = LiteralValue::new_zero(&dt).unwrap();
assert_eq!(expected, actual, "DataType: {:?}", dt);
}

// fixed size list returns all size of null values
let fixed_size_list = DataType::FixedSizeList(Box::new(DataType::Int32), 4);
let actual = LiteralValue::new_zero(&fixed_size_list).unwrap();
let array_arrow = actual.as_series().unwrap().to_arrow();
// the get_type of series is the inner type
assert_eq!(DataType::Int32, actual.get_type());
assert_eq!(4, array_arrow.len());
assert_eq!(4, array_arrow.null_count());
}

#[test]
fn test_unsupported_zero_literal_value() {
let unsupported_types = vec![
DataType::Int8,
DataType::UInt8,
DataType::Int16,
DataType::UInt16,
DataType::Float32,
DataType::Embedding(Box::new(DataType::Int32), 1),
DataType::Map {
key: Box::new(DataType::Int32),
value: Box::new(DataType::Int32),
},
DataType::Image(None),
// others are omitted
];
for dt in unsupported_types {
let actual = LiteralValue::new_zero(&dt);
assert!(actual.is_err());
assert!(matches!(actual.unwrap_err(), DaftError::TypeError(_)));
}
}
}
6 changes: 6 additions & 0 deletions src/daft-dsl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,12 @@ pub fn lit(item: Bound<PyAny>) -> PyResult<PyExpr> {
}
}

#[pyfunction]
pub fn zero_value(dType: PyDataType) -> PyResult<PyExpr> {
let literal_val = LiteralValue::new_zero(&dType.dtype)?;
Ok(Expr::Literal(literal_val).into())
}

// Create a UDF Expression using:
// * `func` - a Python function that takes as input an ordered list of Python Series to execute the user's UDF.
// * `expressions` - an ordered list of Expressions, each representing computation that will be performed, producing a Series to pass into `func`
Expand Down
Loading

0 comments on commit 8b16405

Please sign in to comment.