diff --git a/daft/__init__.py b/daft/__init__.py index dec36e4dc7..bb0771bffa 100644 --- a/daft/__init__.py +++ b/daft/__init__.py @@ -73,7 +73,7 @@ def refresh_logger() -> None: from daft.dataframe import DataFrame from daft.logical.schema import Schema from daft.datatype import DataType, TimeUnit -from daft.expressions import Expression, col, lit, interval +from daft.expressions import Expression, col, lit, interval, zero_lit from daft.io import ( DataCatalogTable, DataCatalogType, @@ -120,6 +120,7 @@ def refresh_logger() -> None: "ImageMode", "ImageFormat", "lit", + "zero_lit", "Series", "TimeUnit", "register_viz_hook", diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index 7edcae7158..41e6e2ca24 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -1135,6 +1135,7 @@ class PyExpr: def eq(expr1: PyExpr, expr2: PyExpr) -> bool: ... def col(name: str) -> PyExpr: ... def lit(item: Any) -> PyExpr: ... +def zero_value(dt: PyDataType) -> PyExpr: ... def date_lit(item: int) -> PyExpr: ... def time_lit(item: int, tu: PyTimeUnit) -> PyExpr: ... def timestamp_lit(item: int, tu: PyTimeUnit, tz: str | None) -> PyExpr: ... diff --git a/daft/expressions/__init__.py b/daft/expressions/__init__.py index 6e07ffe0f7..48debfa1eb 100644 --- a/daft/expressions/__init__.py +++ b/daft/expressions/__init__.py @@ -1,5 +1,5 @@ from __future__ import annotations -from .expressions import Expression, ExpressionsProjection, col, lit, interval +from .expressions import Expression, ExpressionsProjection, col, lit, interval, zero_lit -__all__ = ["Expression", "ExpressionsProjection", "col", "lit", "interval"] +__all__ = ["Expression", "ExpressionsProjection", "col", "lit", "interval", "zero_lit"] diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index d1b52f6f95..7b938e49f2 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -36,6 +36,7 @@ from daft.daft import tokenize_encode as _tokenize_encode from daft.daft import url_download as _url_download from daft.daft import utf8_count_matches as _utf8_count_matches +from daft.daft import zero_value as _zero_value from daft.datatype import DataType, TimeUnit from daft.dependencies import pa from daft.expressions.testing import expr_structurally_equal @@ -133,6 +134,39 @@ def lit(value: object) -> Expression: return Expression._from_pyexpr(lit_value) +def zero_lit(dt: DataType) -> Expression: + """Creates a literal Expression representing a zero value of corresponding data type + + Example: + >>> import daft + >>> from daft import DataType + >>> df = daft.from_pydict({"x": [1, 2, 3]}) + >>> df = df.with_column("y", daft.zero_lit(DataType.int32())) + >>> df.show() + ╭───────┬───────╮ + │ x ┆ y │ + │ --- ┆ --- │ + │ Int64 ┆ Int32 │ + ╞═══════╪═══════╡ + │ 1 ┆ 0 │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ + │ 2 ┆ 0 │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ + │ 3 ┆ 0 │ + ╰───────┴───────╯ + + (Showing first 3 of 3 rows) + + Args: + dt: data type of the zero value + + Returns: + Expression: representing the zero value of the data type + """ + zero = _zero_value(dt._dtype) + return Expression._from_pyexpr(zero) + + def col(name: str) -> Expression: """Creates an Expression referring to the column with the provided name. diff --git a/src/daft-dsl/src/lib.rs b/src/daft-dsl/src/lib.rs index 65af123fed..7c05bf3b2b 100644 --- a/src/daft-dsl/src/lib.rs +++ b/src/daft-dsl/src/lib.rs @@ -39,6 +39,7 @@ pub fn register_modules(parent: &Bound) -> PyResult<()> { parent.add_function(wrap_pyfunction_bound!(python::interval_lit, parent)?)?; parent.add_function(wrap_pyfunction_bound!(python::decimal_lit, parent)?)?; parent.add_function(wrap_pyfunction_bound!(python::series_lit, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(python::zero_value, parent)?)?; parent.add_function(wrap_pyfunction_bound!(python::stateless_udf, parent)?)?; parent.add_function(wrap_pyfunction_bound!(python::stateful_udf, parent)?)?; parent.add_function(wrap_pyfunction_bound!( diff --git a/src/daft-dsl/src/lit.rs b/src/daft-dsl/src/lit.rs index 47b4090be5..2d037a92fc 100644 --- a/src/daft-dsl/src/lit.rs +++ b/src/daft-dsl/src/lit.rs @@ -178,6 +178,67 @@ impl Display for LiteralValue { } impl LiteralValue { + pub fn new_zero(dt: &DataType) -> DaftResult { + Ok(match dt { + DataType::Null => Self::Null, + DataType::Boolean => Self::Boolean(false), + DataType::Utf8 => Self::Utf8(String::new()), + DataType::Binary => Self::Binary(vec![]), + DataType::FixedSizeBinary(usize) => Self::Binary(vec![0; *usize]), + DataType::Int32 => Self::Int32(0), + DataType::UInt32 => Self::UInt32(0), + DataType::Int64 => Self::Int64(0), + DataType::UInt64 => Self::UInt64(0), + DataType::Date => Self::Date(0), + DataType::Time(unit) => Self::Time(0, *unit), + DataType::Timestamp(unit, time_zone) => Self::Timestamp(0, *unit, time_zone.clone()), + DataType::Duration(unit) => Self::Duration(0, *unit), + DataType::Float64 => Self::Float64(0.0), + DataType::Decimal128(precision, scale) => { + Self::Decimal(0, *precision as u8, *scale as i8) + } + DataType::Interval => Self::Interval(IntervalValue::new(0, 0, 0)), + DataType::List(item) => Self::Series(Series::empty("literal", item)), + DataType::FixedSizeList(item, usize) => { + // a list of nulls or zero values? + Self::Series(Series::full_null("literal", item, *usize)) + } + // No support for map type yet + // DataType::Map { .. } => {}, + #[cfg(feature = "python")] + DataType::Python => { + use pyo3::prelude::*; + Self::Python(PyObjectWrapper(Python::with_gil(|py| py.None()))) + } + DataType::Struct(s) => { + let record = s + .iter() + .map(|field| { + let zero = Self::new_zero(&field.dtype); + zero.map(|v| (field.clone(), v)) + }) + .collect::>>()?; + Self::Struct(record) + } + DataType::Int8 + | DataType::UInt8 + | DataType::Int16 + | DataType::UInt16 + | DataType::Float32 => { + return Err(DaftError::TypeError(format!( + "Unsupported numeric type: {:?}", + dt + ))) + } + _ => { + return Err(DaftError::TypeError(format!( + "Unsupported data type: {:?}", + dt + ))) + } + }) + } + pub fn get_type(&self) -> DataType { match self { Self::Null => DataType::Null, @@ -204,57 +265,64 @@ impl LiteralValue { } } - pub fn to_series(&self) -> Series { + fn to_series_helper(&self, field_name: Option<&str>) -> Series { + let field_name = field_name.unwrap_or("literal"); match self { - Self::Null => NullArray::full_null("literal", &DataType::Null, 1).into_series(), - Self::Boolean(val) => BooleanArray::from(("literal", [*val].as_slice())).into_series(), + Self::Null => NullArray::full_null(field_name, &DataType::Null, 1).into_series(), + Self::Boolean(val) => BooleanArray::from((field_name, [*val].as_slice())).into_series(), Self::Utf8(val) => { - Utf8Array::from(("literal", [val.as_str()].as_slice())).into_series() + Utf8Array::from((field_name, [val.as_str()].as_slice())).into_series() } - Self::Binary(val) => BinaryArray::from(("literal", val.as_slice())).into_series(), - Self::Int32(val) => Int32Array::from(("literal", [*val].as_slice())).into_series(), - Self::UInt32(val) => UInt32Array::from(("literal", [*val].as_slice())).into_series(), - Self::Int64(val) => Int64Array::from(("literal", [*val].as_slice())).into_series(), - Self::UInt64(val) => UInt64Array::from(("literal", [*val].as_slice())).into_series(), + Self::Binary(val) => BinaryArray::from((field_name, val.as_slice())).into_series(), + Self::Int32(val) => Int32Array::from((field_name, [*val].as_slice())).into_series(), + Self::UInt32(val) => UInt32Array::from((field_name, [*val].as_slice())).into_series(), + Self::Int64(val) => Int64Array::from((field_name, [*val].as_slice())).into_series(), + Self::UInt64(val) => UInt64Array::from((field_name, [*val].as_slice())).into_series(), Self::Date(val) => { - let physical = Int32Array::from(("literal", [*val].as_slice())); - DateArray::new(Field::new("literal", self.get_type()), physical).into_series() + let physical = Int32Array::from((field_name, [*val].as_slice())); + DateArray::new(Field::new(field_name, self.get_type()), physical).into_series() } Self::Time(val, ..) => { - let physical = Int64Array::from(("literal", [*val].as_slice())); - TimeArray::new(Field::new("literal", self.get_type()), physical).into_series() + let physical = Int64Array::from((field_name, [*val].as_slice())); + TimeArray::new(Field::new(field_name, self.get_type()), physical).into_series() } Self::Timestamp(val, ..) => { - let physical = Int64Array::from(("literal", [*val].as_slice())); - TimestampArray::new(Field::new("literal", self.get_type()), physical).into_series() + let physical = Int64Array::from((field_name, [*val].as_slice())); + TimestampArray::new(Field::new(field_name, self.get_type()), physical).into_series() } Self::Duration(val, ..) => { - let physical = Int64Array::from(("literal", [*val].as_slice())); - DurationArray::new(Field::new("literal", self.get_type()), physical).into_series() + let physical = Int64Array::from((field_name, [*val].as_slice())); + DurationArray::new(Field::new(field_name, self.get_type()), physical).into_series() } Self::Interval(val) => IntervalArray::from_values( - "literal", + field_name, std::iter::once((val.months, val.days, val.nanoseconds)), ) .into_series(), - Self::Float64(val) => Float64Array::from(("literal", [*val].as_slice())).into_series(), + Self::Float64(val) => Float64Array::from((field_name, [*val].as_slice())).into_series(), Self::Decimal(val, ..) => { - let physical = Int128Array::from(("literal", [*val].as_slice())); - Decimal128Array::new(Field::new("literal", self.get_type()), physical).into_series() + let physical = Int128Array::from((field_name, [*val].as_slice())); + Decimal128Array::new(Field::new(field_name, self.get_type()), physical) + .into_series() } - Self::Series(series) => series.clone().rename("literal"), + Self::Series(series) => series.clone().rename(field_name), #[cfg(feature = "python")] - Self::Python(val) => PythonArray::from(("literal", vec![val.0.clone()])).into_series(), + Self::Python(val) => PythonArray::from((field_name, vec![val.0.clone()])).into_series(), Self::Struct(entries) => { let struct_dtype = DataType::Struct(entries.keys().cloned().collect()); - let struct_field = Field::new("literal", struct_dtype); + let struct_field = Field::new(field_name, struct_dtype); - let values = entries.values().map(|v| v.to_series()).collect(); + let values = entries + .iter() + .map(|(field, value)| value.to_series_helper(Some(&field.name))) + .collect(); StructArray::new(struct_field, values, None).into_series() } } } - + pub fn to_series(&self) -> Series { + self.to_series_helper(None) + } pub fn display_sql(&self, buffer: &mut W) -> io::Result<()> { let display_sql_err = Err(io::Error::new( io::ErrorKind::Other, @@ -554,9 +622,12 @@ pub fn literals_to_series(values: &[LiteralValue]) -> DaftResult { #[cfg(test)] mod test { - use daft_core::prelude::*; + use common_error::DaftError; + use daft_core::{datatypes::IntervalValue, prelude::*}; use super::LiteralValue; + #[cfg(feature = "python")] + use crate::pyobj_serde::PyObjectWrapper; #[test] fn test_literals_to_series() { @@ -598,4 +669,117 @@ mod test { let actual = super::literals_to_series(&values); assert!(actual.is_err()); } + + #[test] + fn test_struct_literal_to_serials() { + let values = vec![LiteralValue::Int32(1), LiteralValue::Int64(2)]; + let fields = vec![ + Field::new("a", DataType::Int32), + Field::new("b", DataType::Int64), + ]; + let struct_literal = + LiteralValue::Struct(fields.into_iter().zip(values.into_iter()).collect()); + let series = struct_literal.to_series(); + assert_eq!(series.len(), 1); + assert_eq!( + series.data_type(), + &DataType::Struct(vec![ + Field::new("a", DataType::Int32), + Field::new("b", DataType::Int64), + ]) + ); + } + + #[test] + fn test_zero_literal_value() { + let type_and_expected_values = vec![ + (DataType::Null, LiteralValue::Null), + (DataType::Boolean, LiteralValue::Boolean(false)), + (DataType::Utf8, LiteralValue::Utf8("".to_string())), + (DataType::Binary, LiteralValue::Binary(vec![])), + (DataType::FixedSizeBinary(1), LiteralValue::Binary(vec![0])), + (DataType::Int32, LiteralValue::Int32(0)), + (DataType::UInt32, LiteralValue::UInt32(0)), + (DataType::Int64, LiteralValue::Int64(0)), + (DataType::UInt64, LiteralValue::UInt64(0)), + (DataType::Date, LiteralValue::Date(0)), + ( + DataType::Time(TimeUnit::Microseconds), + LiteralValue::Time(0, TimeUnit::Microseconds), + ), + ( + DataType::Timestamp(TimeUnit::Microseconds, Some("UTC".to_string())), + LiteralValue::Timestamp(0, TimeUnit::Microseconds, Some("UTC".to_string())), + ), + ( + DataType::Duration(TimeUnit::Microseconds), + LiteralValue::Duration(0, TimeUnit::Microseconds), + ), + (DataType::Float64, LiteralValue::Float64(0.0)), + (DataType::Decimal128(1, 1), LiteralValue::Decimal(0, 1, 1)), + ( + DataType::Interval, + LiteralValue::Interval(IntervalValue::new(0, 0, 0)), + ), + ( + DataType::List(Box::new(DataType::Int32)), + LiteralValue::Series(Series::empty("literal", &DataType::Int32)), + ), + #[cfg(feature = "python")] + (DataType::Python, { + use pyo3::prelude::*; + LiteralValue::Python(PyObjectWrapper(Python::with_gil(|py| py.None()))) + }), + ( + DataType::Struct(vec![ + Field::new("a", DataType::Int32), + Field::new("b", DataType::Int64), + ]), + LiteralValue::Struct( + vec![ + (Field::new("a", DataType::Int32), LiteralValue::Int32(0)), + (Field::new("b", DataType::Int64), LiteralValue::Int64(0)), + ] + .into_iter() + .collect(), + ), + ), + ]; + for (dt, expected) in type_and_expected_values { + let actual = LiteralValue::new_zero(&dt).unwrap(); + assert_eq!(expected, actual, "DataType: {:?}", dt); + } + + // fixed size list returns all size of null values + let fixed_size_list = DataType::FixedSizeList(Box::new(DataType::Int32), 4); + let actual = LiteralValue::new_zero(&fixed_size_list).unwrap(); + let array_arrow = actual.as_series().unwrap().to_arrow(); + // the get_type of series is the inner type + assert_eq!(DataType::Int32, actual.get_type()); + assert_eq!(4, array_arrow.len()); + assert_eq!(4, array_arrow.null_count()); + } + + #[test] + fn test_unsupported_zero_literal_value() { + let unsupported_types = vec![ + DataType::Int8, + DataType::UInt8, + DataType::Int16, + DataType::UInt16, + DataType::Float32, + DataType::Embedding(Box::new(DataType::Int32), 1), + DataType::Map { + key: Box::new(DataType::Int32), + value: Box::new(DataType::Int32), + }, + DataType::Image(None), + // others are omitted + ]; + for dt in unsupported_types { + let actual = LiteralValue::new_zero(&dt); + assert!(actual.is_err()); + assert!(matches!(actual.unwrap_err(), DaftError::TypeError(_))); + } + } } diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index 0ba4ad8b92..5a09af1a63 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -174,6 +174,12 @@ pub fn lit(item: Bound) -> PyResult { } } +#[pyfunction] +pub fn zero_value(dType: PyDataType) -> PyResult { + let literal_val = LiteralValue::new_zero(&dType.dtype)?; + Ok(Expr::Literal(literal_val).into()) +} + // Create a UDF Expression using: // * `func` - a Python function that takes as input an ordered list of Python Series to execute the user's UDF. // * `expressions` - an ordered list of Expressions, each representing computation that will be performed, producing a Series to pass into `func` diff --git a/tests/cookbook/test_literals.py b/tests/cookbook/test_literals.py index f652136439..c1ec159fdb 100644 --- a/tests/cookbook/test_literals.py +++ b/tests/cookbook/test_literals.py @@ -2,7 +2,7 @@ import pandas as pd -from daft import DataType, col, lit +from daft import DataType, col, lit, zero_lit from tests.conftest import assert_df_equals @@ -54,3 +54,36 @@ def test_literal_column_computation_apply(daft_df, service_requests_csv_pd_df): daft_pd_df = daft_df.to_pandas() service_requests_csv_pd_df["literal_col"] = "bar" assert_df_equals(daft_pd_df, service_requests_csv_pd_df) + + +def test_literal_zero_value(daft_df, service_requests_csv_pd_df): + daft_df = daft_df.select( + "*", + zero_lit(DataType.null()).alias("zero_null"), + zero_lit(DataType.bool()).alias("zero_bool"), + zero_lit(DataType.string()).alias("zero_string"), + zero_lit(DataType.binary()).alias("zero_binary"), + zero_lit(DataType.int32()).alias("zero_int"), + zero_lit(DataType.uint32()).alias("zero_uint"), + zero_lit(DataType.int64()).alias("zero_int64"), + zero_lit(DataType.uint64()).alias("zero_uint64"), + zero_lit(DataType.float64()).alias("zero_float"), + zero_lit(DataType.timestamp("s", None)).alias("zero_timestamp"), + zero_lit(DataType.python()).alias("zero_python"), + zero_lit(DataType.struct({"foo": DataType.int32(), "bar": DataType.string()})).alias("zero_struct"), + ) + daft_pd_df = daft_df.to_pandas() + row_count = len(daft_pd_df) + service_requests_csv_pd_df["zero_null"] = None + service_requests_csv_pd_df["zero_bool"] = False + service_requests_csv_pd_df["zero_string"] = "" + service_requests_csv_pd_df["zero_binary"] = b"" + service_requests_csv_pd_df["zero_int"] = pd.array([0] * row_count, dtype="int32") + service_requests_csv_pd_df["zero_uint"] = pd.array([0] * row_count, dtype="uint32") + service_requests_csv_pd_df["zero_int64"] = pd.array([0] * row_count, dtype="int64") + service_requests_csv_pd_df["zero_uint64"] = pd.array([0] * row_count, dtype="uint64") + service_requests_csv_pd_df["zero_float"] = pd.array([0.0] * row_count, dtype="float64") + service_requests_csv_pd_df["zero_timestamp"] = pd.Timestamp("1970-01-01") + service_requests_csv_pd_df["zero_python"] = None + service_requests_csv_pd_df["zero_struct"] = pd.Series([{"foo": 0, "bar": ""}] * row_count, dtype="object") + assert_df_equals(daft_pd_df, service_requests_csv_pd_df)