From 50d9b805c7708808d7f3741e6908edbb06d25267 Mon Sep 17 00:00:00 2001 From: Kev Wang Date: Thu, 27 Jun 2024 18:36:54 -0700 Subject: [PATCH] [FEAT] Add struct get syntactic sugar (#2367) Adds the ability to query struct and map fields by using the dot syntax, such as `col("a.b")` turning into `col("a").struct.get("b")`. This PR also includes a minor refactor of agg expression checking and extraction, just moving it out of the builder and into the `resolve_expr` and `resolve_aggexpr` functions that also deal with the syntactic sugar. I changed this since we were talking about how brittle it would be to do the syntactic sugar conversion in the builder and realized that it applies to aggregations too. --- Cargo.lock | 1 + daft/daft.pyi | 1 + daft/dataframe/dataframe.py | 12 +- daft/expressions/expressions.py | 59 ++-- src/daft-dsl/Cargo.toml | 1 + src/daft-dsl/src/expr.rs | 285 +++++++++++++++++- src/daft-dsl/src/lib.rs | 4 + src/daft-dsl/src/python.rs | 6 + src/daft-plan/src/builder.rs | 125 +------- src/daft-plan/src/logical_ops/agg.rs | 23 +- src/daft-plan/src/logical_ops/explode.rs | 8 +- src/daft-plan/src/logical_ops/filter.rs | 7 +- src/daft-plan/src/logical_ops/join.rs | 16 +- src/daft-plan/src/logical_ops/pivot.rs | 24 +- src/daft-plan/src/logical_ops/project.rs | 16 +- src/daft-plan/src/logical_ops/repartition.rs | 17 +- src/daft-plan/src/logical_ops/sink.rs | 29 +- src/daft-plan/src/logical_ops/sort.rs | 17 +- src/daft-plan/src/logical_ops/unpivot.rs | 23 +- .../rules/push_down_projection.rs | 2 +- src/daft-plan/src/logical_plan.rs | 4 +- tests/dataframe/test_getter_sugar.py | 27 ++ 22 files changed, 479 insertions(+), 228 deletions(-) create mode 100644 tests/dataframe/test_getter_sugar.py diff --git a/Cargo.lock b/Cargo.lock index 73d8add41d..8eb889d91f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1668,6 +1668,7 @@ dependencies = [ "daft-io", "daft-sketch", "indexmap 2.2.6", + "itertools 0.11.0", "pyo3", "pyo3-log", "serde", diff --git a/daft/daft.pyi 
b/daft/daft.pyi index 29b62fdafe..d1310d509c 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -1151,6 +1151,7 @@ def timestamp_lit(item: int, tu: PyTimeUnit, tz: str | None) -> PyExpr: ... def decimal_lit(sign: bool, digits: tuple[int, ...], exp: int) -> PyExpr: ... def series_lit(item: PySeries) -> PyExpr: ... def udf(func: Callable, expressions: list[PyExpr], return_dtype: PyDataType) -> PyExpr: ... +def resolve_expr(expr: PyExpr, schema: PySchema) -> tuple[PyExpr, PyField]: ... class PySeries: @staticmethod diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index c1b479d4b2..75b3566b34 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -31,13 +31,7 @@ from daft.api_annotations import DataframePublicAPI from daft.context import get_context from daft.convert import InputListType -from daft.daft import ( - FileFormat, - IOConfig, - JoinStrategy, - JoinType, - ResourceRequest, -) +from daft.daft import FileFormat, IOConfig, JoinStrategy, JoinType, ResourceRequest, resolve_expr from daft.dataframe.preview import DataFramePreview from daft.datatype import DataType from daft.errors import ExpressionTypeError @@ -769,8 +763,8 @@ def __getitem__(self, item: Union[slice, int, str, Iterable[Union[str, int]]]) - return result elif isinstance(item, str): schema = self._builder.schema() - field = schema[item] - return col(field.name) + expr, _ = resolve_expr(col(item)._expr, schema._schema) + return Expression._from_pyexpr(expr) elif isinstance(item, Iterable): schema = self._builder.schema() diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 0840758bb7..082b13649e 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -1735,27 +1735,48 @@ def get(self, key: Expression) -> Expression: """Retrieves the value for a key in a map column Example: - >>> import pyarrrow as pa + >>> import pyarrow as pa >>> import daft - >>> pa_array = pa.array([[(1, 2)],[],[(2,1)]], 
type=pa.map_(pa.int64(), pa.int64())) + >>> pa_array = pa.array([[("a", 1)],[],[("b",2)]], type=pa.map_(pa.string(), pa.int64())) >>> df = daft.from_arrow(pa.table({"map_col": pa_array})) - >>> df = df.with_column("1", df["map_col"].map.get(1)) - >>> df.show() - ╭───────────────────────────────────────┬───────╮ - │ map_col ┆ 1 │ - │ --- ┆ --- │ - │ Map[Struct[key: Int64, value: Int64]] ┆ Int64 │ - ╞═══════════════════════════════════════╪═══════╡ - │ [{key: 1, ┆ 2 │ - │ value: 2, ┆ │ - │ }] ┆ │ - ├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ - │ [] ┆ None │ - ├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ - │ [{key: 2, ┆ None │ - │ value: 1, ┆ │ - │ }] ┆ │ - ╰───────────────────────────────────────┴───────╯ + >>> df1 = df.with_column("a", df["map_col"].map.get("a")) + >>> df1.show() + ╭───────────┬───────╮ + │ map_col ┆ a │ + │ --- ┆ --- │ + │ Map[Utf8] ┆ Int64 │ + ╞═══════════╪═══════╡ + │ [{key: a, ┆ 1 │ + │ value: 1, ┆ │ + │ }] ┆ │ + ├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ + │ [] ┆ None │ + ├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ + │ [{key: b, ┆ None │ + │ value: 2, ┆ │ + │ }] ┆ │ + ╰───────────┴───────╯ + (Showing first 3 of 3 rows) + >>> + >>> # you may also use the "column.key" syntax to get map values + >>> df2 = df.with_column("b", df["map_col.b"]) + >>> df2.show() + ╭───────────┬───────╮ + │ map_col ┆ b │ + │ --- ┆ --- │ + │ Map[Utf8] ┆ Int64 │ + ╞═══════════╪═══════╡ + │ [{key: a, ┆ None │ + │ value: 1, ┆ │ + │ }] ┆ │ + ├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ + │ [] ┆ None │ + ├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ + │ [{key: b, ┆ 2 │ + │ value: 2, ┆ │ + │ }] ┆ │ + ╰───────────┴───────╯ + (Showing first 3 of 3 rows) Args: key: the key to retrieve diff --git a/src/daft-dsl/Cargo.toml b/src/daft-dsl/Cargo.toml index d6a9c0768a..d745015f7f 100644 --- a/src/daft-dsl/Cargo.toml +++ b/src/daft-dsl/Cargo.toml @@ -7,6 +7,7 @@ daft-core = {path = "../daft-core", default-features = false} daft-io = {path = "../daft-io", default-features = false} daft-sketch = {path = "../daft-sketch", default-features = false} indexmap = 
{workspace = true} +itertools = {workspace = true} pyo3 = {workspace = true, optional = true} pyo3-log = {workspace = true, optional = true} serde = {workspace = true} diff --git a/src/daft-dsl/src/expr.rs b/src/daft-dsl/src/expr.rs index f21062747d..a2a294a790 100644 --- a/src/daft-dsl/src/expr.rs +++ b/src/daft-dsl/src/expr.rs @@ -20,6 +20,8 @@ use common_error::{DaftError, DaftResult}; use serde::{Deserialize, Serialize}; use std::{ + cmp::Ordering, + collections::{BinaryHeap, HashMap}, fmt::{Debug, Display, Formatter, Result}, io::{self, Write}, sync::Arc, @@ -202,7 +204,7 @@ impl AggExpr { Sum(_) => Sum(children[0].clone()), Mean(_) => Mean(children[0].clone()), Min(_) => Min(children[0].clone()), - Max(_) => Mean(children[0].clone()), + Max(_) => Max(children[0].clone()), AnyValue(_, ignore_nulls) => AnyValue(children[0].clone(), *ignore_nulls), List(_) => List(children[0].clone()), Concat(_) => Concat(children[0].clone()), @@ -345,6 +347,12 @@ impl AggExpr { } } +impl From<&AggExpr> for ExprRef { + fn from(agg_expr: &AggExpr) -> Self { + Arc::new(Expr::Agg(agg_expr.clone())) + } +} + impl AsRef for Expr { fn as_ref(&self) -> &Expr { self } } @@ -1029,6 +1037,197 @@ impl Operator { } } +/// Converts an expression with syntactic sugar into struct gets. +/// Does left-associative parsing to resolve ambiguity. +/// +/// For example, if col("a.b.c") could be interpreted as either col("a.b").struct.get("c") +/// or col("a").struct.get("b.c"), this function will resolve it to col("a.b").struct.get("c").
+fn substitute_expr_getter_sugar(expr: ExprRef, schema: &Schema) -> DaftResult { + use common_treenode::{Transformed, TransformedResult, TreeNode}; + + #[derive(PartialEq, Eq)] + struct BfsState<'a> { + name: String, + expr: ExprRef, + field: &'a Field, + } + + impl Ord for BfsState<'_> { + fn cmp(&self, other: &Self) -> Ordering { + self.name.cmp(&other.name) + } + } + + impl PartialOrd for BfsState<'_> { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } + } + + let mut pq: BinaryHeap = BinaryHeap::new(); + + for field in schema.fields.values() { + pq.push(BfsState { + name: field.name.clone(), + expr: Arc::new(Expr::Column(field.name.clone().into())), + field, + }); + } + + let mut str_to_get_expr: HashMap = HashMap::new(); + + while let Some(BfsState { name, expr, field }) = pq.pop() { + if !str_to_get_expr.contains_key(&name) { + str_to_get_expr.insert(name.clone(), expr.clone()); + } + + if let DataType::Struct(children) = &field.dtype { + for child in children { + pq.push(BfsState { + name: format!("{}.{}", name, child.name), + expr: crate::functions::struct_::get(expr.clone(), &child.name), + field: child, + }); + } + } + } + + expr.transform(|e| match e.as_ref() { + Expr::Column(name) => str_to_get_expr + .get(name.as_ref()) + .ok_or(DaftError::ValueError(format!( + "Column not found in schema: {name}" + ))) + .map(|get_expr| match get_expr.as_ref() { + Expr::Column(_) => Transformed::no(e), + _ => Transformed::yes(get_expr.clone()), + }), + _ => Ok(Transformed::no(e)), + }) + .data() +} + +fn expr_has_agg(expr: &ExprRef) -> bool { + use Expr::*; + + match expr.as_ref() { + Agg(_) => true, + Column(_) | Literal(_) => false, + Alias(e, _) | Cast(e, _) | Not(e) | IsNull(e) | NotNull(e) => expr_has_agg(e), + BinaryOp { left, right, .. } => expr_has_agg(left) || expr_has_agg(right), + Function { inputs, .. 
} => inputs.iter().any(expr_has_agg), + IsIn(l, r) | FillNull(l, r) => expr_has_agg(l) || expr_has_agg(r), + Between(v, l, u) => expr_has_agg(v) || expr_has_agg(l) || expr_has_agg(u), + IfElse { + if_true, + if_false, + predicate, + } => expr_has_agg(if_true) || expr_has_agg(if_false) || expr_has_agg(predicate), + } +} + +fn extract_agg_expr(expr: &Expr) -> DaftResult { + use Expr::*; + + match expr { + Agg(agg_expr) => Ok(agg_expr.clone()), + Function { func, inputs } => Ok(AggExpr::MapGroups { + func: func.clone(), + inputs: inputs.clone(), + }), + Alias(e, name) => extract_agg_expr(e).map(|agg_expr| { + use AggExpr::*; + + // reorder expressions so that alias goes before agg + match agg_expr { + Count(e, count_mode) => Count(Alias(e, name.clone()).into(), count_mode), + Sum(e) => Sum(Alias(e, name.clone()).into()), + ApproxSketch(e) => ApproxSketch(Alias(e, name.clone()).into()), + ApproxPercentile(ApproxPercentileParams { + child: e, + percentiles, + force_list_output, + }) => ApproxPercentile(ApproxPercentileParams { + child: Alias(e, name.clone()).into(), + percentiles, + force_list_output, + }), + MergeSketch(e) => MergeSketch(Alias(e, name.clone()).into()), + Mean(e) => Mean(Alias(e, name.clone()).into()), + Min(e) => Min(Alias(e, name.clone()).into()), + Max(e) => Max(Alias(e, name.clone()).into()), + AnyValue(e, ignore_nulls) => AnyValue(Alias(e, name.clone()).into(), ignore_nulls), + List(e) => List(Alias(e, name.clone()).into()), + Concat(e) => Concat(Alias(e, name.clone()).into()), + MapGroups { func, inputs } => MapGroups { + func, + inputs: inputs + .into_iter() + .map(|input| input.alias(name.clone())) + .collect(), + }, + } + }), + // TODO(Kevin): Support a mix of aggregation and non-aggregation expressions + // as long as the final value always has a cardinality of 1. 
+ _ => Err(DaftError::ValueError(format!( + "Expected aggregation expression, but got: {expr}" + ))), + } +} + +/// Resolves and validates the expression with a schema, returning the new expression and its field. +pub fn resolve_expr(expr: ExprRef, schema: &Schema) -> DaftResult<(ExprRef, Field)> { + // TODO(Kevin): Support aggregation expressions everywhere + if expr_has_agg(&expr) { + return Err(DaftError::ValueError(format!( + "Aggregation expressions are currently only allowed in agg and pivot: {expr}\nIf you would like to have this feature, please see https://github.com/Eventual-Inc/Daft/issues/1979#issue-2170913383", + ))); + } + + let resolved_expr = substitute_expr_getter_sugar(expr, schema)?; + let resolved_field = resolved_expr.to_field(schema)?; + Ok((resolved_expr, resolved_field)) +} + +pub fn resolve_exprs( + exprs: Vec, + schema: &Schema, +) -> DaftResult<(Vec, Vec)> { + let resolved_iter = exprs.into_iter().map(|e| resolve_expr(e, schema)); + itertools::process_results(resolved_iter, |res| res.unzip()) +} + +/// Resolves and validates the expression with a schema, returning the extracted aggregation expression and its field. 
+pub fn resolve_aggexpr(expr: ExprRef, schema: &Schema) -> DaftResult<(AggExpr, Field)> { + let agg_expr = extract_agg_expr(&expr)?; + + let has_nested_agg = agg_expr.children().iter().any(expr_has_agg); + + if has_nested_agg { + return Err(DaftError::ValueError(format!( + "Nested aggregation expressions are not supported: {expr}\nIf you would like to have this feature, please see https://github.com/Eventual-Inc/Daft/issues/1979#issue-2170913383" + ))); + } + + let resolved_children = agg_expr + .children() + .into_iter() + .map(|e| substitute_expr_getter_sugar(e, schema)) + .collect::>>()?; + let resolved_agg = agg_expr.with_new_children(resolved_children); + let resolved_field = resolved_agg.to_field(schema)?; + Ok((resolved_agg, resolved_field)) +} + +pub fn resolve_aggexprs( + exprs: Vec, + schema: &Schema, +) -> DaftResult<(Vec, Vec)> { + let resolved_iter = exprs.into_iter().map(|e| resolve_aggexpr(e, schema)); + itertools::process_results(resolved_iter, |res| res.unzip()) +} + #[cfg(test)] mod tests { @@ -1114,4 +1313,88 @@ mod tests { Ok(()) } + + #[test] + fn test_substitute_expr_getter_sugar() -> DaftResult<()> { + use crate::functions::struct_::get as struct_get; + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64)])?); + + assert_eq!(substitute_expr_getter_sugar(col("a"), &schema)?, col("a")); + assert!(substitute_expr_getter_sugar(col("a.b"), &schema).is_err()); + assert!(matches!( + substitute_expr_getter_sugar(col("a.b"), &schema).unwrap_err(), + DaftError::ValueError(..) 
+ )); + + let schema = Arc::new(Schema::new(vec![Field::new( + "a", + DataType::Struct(vec![Field::new("b", DataType::Int64)]), + )])?); + + assert_eq!(substitute_expr_getter_sugar(col("a"), &schema)?, col("a")); + assert_eq!( + substitute_expr_getter_sugar(col("a.b"), &schema)?, + struct_get(col("a"), "b") + ); + assert_eq!( + substitute_expr_getter_sugar(col("a.b").alias("c"), &schema)?, + struct_get(col("a"), "b").alias("c") + ); + + let schema = Arc::new(Schema::new(vec![Field::new( + "a", + DataType::Struct(vec![Field::new( + "b", + DataType::Struct(vec![Field::new("c", DataType::Int64)]), + )]), + )])?); + + assert_eq!( + substitute_expr_getter_sugar(col("a.b"), &schema)?, + struct_get(col("a"), "b") + ); + assert_eq!( + substitute_expr_getter_sugar(col("a.b.c"), &schema)?, + struct_get(struct_get(col("a"), "b"), "c") + ); + + let schema = Arc::new(Schema::new(vec![ + Field::new( + "a", + DataType::Struct(vec![Field::new( + "b", + DataType::Struct(vec![Field::new("c", DataType::Int64)]), + )]), + ), + Field::new("a.b", DataType::Int64), + ])?); + + assert_eq!( + substitute_expr_getter_sugar(col("a.b"), &schema)?, + col("a.b") + ); + assert_eq!( + substitute_expr_getter_sugar(col("a.b.c"), &schema)?, + struct_get(struct_get(col("a"), "b"), "c") + ); + + let schema = Arc::new(Schema::new(vec![ + Field::new( + "a", + DataType::Struct(vec![Field::new("b.c", DataType::Int64)]), + ), + Field::new( + "a.b", + DataType::Struct(vec![Field::new("c", DataType::Int64)]), + ), + ])?); + + assert_eq!( + substitute_expr_getter_sugar(col("a.b.c"), &schema)?, + struct_get(col("a.b"), "c") + ); + + Ok(()) + } } diff --git a/src/daft-dsl/src/lib.rs b/src/daft-dsl/src/lib.rs index 0ae3ad090d..fb85b1e083 100644 --- a/src/daft-dsl/src/lib.rs +++ b/src/daft-dsl/src/lib.rs @@ -1,4 +1,6 @@ #![feature(let_chains)] +#![feature(if_let_guard)] + mod arithmetic; mod expr; pub mod functions; @@ -12,6 +14,7 @@ mod treenode; pub use common_treenode; pub use expr::binary_op; pub use 
expr::col; +pub use expr::{resolve_aggexpr, resolve_aggexprs, resolve_expr, resolve_exprs}; pub use expr::{AggExpr, ApproxPercentileParams, Expr, ExprRef, Operator}; pub use lit::{lit, null_lit, Literal, LiteralValue}; #[cfg(feature = "python")] @@ -30,6 +33,7 @@ pub fn register_modules(_py: Python, parent: &PyModule) -> PyResult<()> { parent.add_wrapped(wrap_pyfunction!(python::series_lit))?; parent.add_wrapped(wrap_pyfunction!(python::udf))?; parent.add_wrapped(wrap_pyfunction!(python::eq))?; + parent.add_wrapped(wrap_pyfunction!(python::resolve_expr))?; Ok(()) } diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index b26525ae42..cc7878374a 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -175,6 +175,12 @@ pub fn eq(expr1: &PyExpr, expr2: &PyExpr) -> PyResult { Ok(expr1.expr == expr2.expr) } +#[pyfunction] +pub fn resolve_expr(expr: &PyExpr, schema: &PySchema) -> PyResult<(PyExpr, PyField)> { + let (resolved_expr, field) = crate::resolve_expr(expr.expr.clone(), &schema.schema)?; + Ok((resolved_expr.into(), field.into())) +} + #[derive(FromPyObject)] pub enum ApproxPercentileInput { Single(f64), diff --git a/src/daft-plan/src/builder.rs b/src/daft-plan/src/builder.rs index d29ec01bfb..b16c1a42d7 100644 --- a/src/daft-plan/src/builder.rs +++ b/src/daft-plan/src/builder.rs @@ -14,13 +14,13 @@ use crate::{ source_info::SourceInfo, ResourceRequest, }; -use common_error::{DaftError, DaftResult}; +use common_error::DaftResult; use common_io_config::IOConfig; use daft_core::{ join::{JoinStrategy, JoinType}, schema::{Schema, SchemaRef}, }; -use daft_dsl::{col, ApproxPercentileParams, Expr, ExprRef}; +use daft_dsl::{col, ExprRef}; use daft_scan::{file_format::FileFormat, Pushdowns, ScanExternalInfo, ScanOperatorRef}; #[cfg(feature = "python")] @@ -50,101 +50,6 @@ impl LogicalPlanBuilder { } } -fn check_for_agg(expr: &ExprRef) -> bool { - use Expr::*; - - match expr.as_ref() { - Agg(_) => true, - Column(_) | Literal(_) => false, 
- Alias(e, _) | Cast(e, _) | Not(e) | IsNull(e) | NotNull(e) => check_for_agg(e), - BinaryOp { left, right, .. } => check_for_agg(left) || check_for_agg(right), - Function { inputs, .. } => inputs.iter().any(check_for_agg), - IsIn(l, r) | FillNull(l, r) => check_for_agg(l) || check_for_agg(r), - Between(v, l, u) => check_for_agg(v) || check_for_agg(l) || check_for_agg(u), - IfElse { - if_true, - if_false, - predicate, - } => check_for_agg(if_true) || check_for_agg(if_false) || check_for_agg(predicate), - } -} - -fn err_if_agg(fn_name: &str, exprs: &Vec) -> DaftResult<()> { - for e in exprs { - if check_for_agg(e) { - return Err(DaftError::ValueError(format!( - "Aggregation expressions are not currently supported in {fn_name}: {e}\nIf you would like to have this feature, please see https://github.com/Eventual-Inc/Daft/issues/1979#issue-2170913383", - fn_name=fn_name, - e=e - ))); - } - } - Ok(()) -} - -fn extract_agg_expr(expr: &Expr) -> DaftResult { - use Expr::*; - - match expr { - Agg(agg_expr) => Ok(agg_expr.clone()), - Function { func, inputs } => Ok(daft_dsl::AggExpr::MapGroups { - func: func.clone(), - inputs: inputs.clone(), - }), - Alias(e, name) => extract_agg_expr(e).map(|agg_expr| { - use daft_dsl::AggExpr::*; - - // reorder expressions so that alias goes before agg - match agg_expr { - Count(e, count_mode) => Count(Alias(e, name.clone()).into(), count_mode), - Sum(e) => Sum(Alias(e, name.clone()).into()), - ApproxSketch(e) => ApproxSketch(Alias(e, name.clone()).into()), - ApproxPercentile(ApproxPercentileParams { - child: e, - percentiles, - force_list_output, - }) => ApproxPercentile(ApproxPercentileParams { - child: Alias(e, name.clone()).into(), - percentiles, - force_list_output, - }), - MergeSketch(e) => MergeSketch(Alias(e, name.clone()).into()), - Mean(e) => Mean(Alias(e, name.clone()).into()), - Min(e) => Min(Alias(e, name.clone()).into()), - Max(e) => Max(Alias(e, name.clone()).into()), - AnyValue(e, ignore_nulls) => AnyValue(Alias(e, 
name.clone()).into(), ignore_nulls), - List(e) => List(Alias(e, name.clone()).into()), - Concat(e) => Concat(Alias(e, name.clone()).into()), - MapGroups { func, inputs } => MapGroups { - func, - inputs: inputs - .into_iter() - .map(|input| input.alias(name.clone())) - .collect(), - }, - } - }), - // TODO(Kevin): Support a mix of aggregation and non-aggregation expressions - // as long as the final value always has a cardinality of 1. - _ => Err(DaftError::ValueError(format!( - "Expected aggregation expression, but got: {expr}" - ))), - } -} - -fn extract_and_check_agg_expr(expr: &Expr) -> DaftResult { - let agg_expr = extract_agg_expr(expr)?; - let has_nested_agg = agg_expr.children().iter().any(check_for_agg); - - if has_nested_agg { - Err(DaftError::ValueError(format!( - "Nested aggregation expressions are not supported: {expr}\nIf you would like to have this feature, please see https://github.com/Eventual-Inc/Daft/issues/1979#issue-2170913383" - ))) - } else { - Ok(agg_expr) - } -} - impl LogicalPlanBuilder { #[cfg(feature = "python")] pub fn in_memory_scan( @@ -204,8 +109,6 @@ impl LogicalPlanBuilder { } pub fn select(&self, to_select: Vec) -> DaftResult { - err_if_agg("project", &to_select)?; - let logical_plan: LogicalPlan = logical_ops::Project::try_new(self.plan.clone(), to_select, Default::default())?.into(); Ok(logical_plan.into()) @@ -216,8 +119,6 @@ impl LogicalPlanBuilder { columns: Vec, resource_request: ResourceRequest, ) -> DaftResult { - err_if_agg("with_columns", &columns)?; - let fields = &self.schema().fields; let current_col_names = fields .iter() @@ -272,8 +173,6 @@ impl LogicalPlanBuilder { } pub fn filter(&self, predicate: ExprRef) -> DaftResult { - err_if_agg("filter", &vec![predicate.to_owned()])?; - let logical_plan: LogicalPlan = logical_ops::Filter::try_new(self.plan.clone(), predicate)?.into(); Ok(logical_plan.into()) @@ -286,8 +185,6 @@ impl LogicalPlanBuilder { } pub fn explode(&self, to_explode: Vec) -> DaftResult { - 
err_if_agg("explode", &to_explode)?; - let logical_plan: LogicalPlan = logical_ops::Explode::try_new(self.plan.clone(), to_explode)?.into(); Ok(logical_plan.into()) @@ -332,8 +229,6 @@ impl LogicalPlanBuilder { } pub fn sort(&self, sort_by: Vec, descending: Vec) -> DaftResult { - err_if_agg("sort", &sort_by)?; - let logical_plan: LogicalPlan = logical_ops::Sort::try_new(self.plan.clone(), sort_by, descending)?.into(); Ok(logical_plan.into()) @@ -344,8 +239,6 @@ impl LogicalPlanBuilder { num_partitions: Option, partition_by: Vec, ) -> DaftResult { - err_if_agg("hash_repartition", &partition_by)?; - let logical_plan: LogicalPlan = logical_ops::Repartition::try_new( self.plan.clone(), RepartitionSpec::Hash(HashRepartitionConfig::new(num_partitions, partition_by)), @@ -393,12 +286,6 @@ impl LogicalPlanBuilder { agg_exprs: Vec, groupby_exprs: Vec, ) -> DaftResult { - let agg_exprs = agg_exprs - .iter() - .map(|v| v.as_ref()) - .map(extract_and_check_agg_expr) - .collect::>>()?; - let logical_plan: LogicalPlan = logical_ops::Aggregate::try_new(self.plan.clone(), agg_exprs, groupby_exprs)?.into(); Ok(logical_plan.into()) @@ -412,7 +299,6 @@ impl LogicalPlanBuilder { agg_expr: ExprRef, names: Vec, ) -> DaftResult { - let agg_expr = extract_and_check_agg_expr(agg_expr.as_ref())?; let pivot_logical_plan: LogicalPlan = logical_ops::Pivot::try_new( self.plan.clone(), group_by, @@ -433,9 +319,6 @@ impl LogicalPlanBuilder { join_type: JoinType, join_strategy: Option, ) -> DaftResult { - err_if_agg("join", &left_on)?; - err_if_agg("join", &right_on)?; - let logical_plan: LogicalPlan = logical_ops::Join::try_new( self.plan.clone(), right.plan.clone(), @@ -468,10 +351,6 @@ impl LogicalPlanBuilder { compression: Option, io_config: Option, ) -> DaftResult { - if let Some(partition_cols) = &partition_cols { - err_if_agg("table_write", partition_cols)?; - } - let sink_info = SinkInfo::OutputFileInfo(OutputFileInfo::new( root_dir.into(), file_format, diff --git 
a/src/daft-plan/src/logical_ops/agg.rs b/src/daft-plan/src/logical_ops/agg.rs index a601540c5a..4c841db8fb 100644 --- a/src/daft-plan/src/logical_ops/agg.rs +++ b/src/daft-plan/src/logical_ops/agg.rs @@ -4,7 +4,7 @@ use itertools::Itertools; use snafu::ResultExt; use daft_core::schema::{Schema, SchemaRef}; -use daft_dsl::{AggExpr, ExprRef}; +use daft_dsl::{resolve_aggexprs, resolve_exprs, AggExpr, ExprRef}; use crate::logical_plan::{self, CreationSnafu}; use crate::LogicalPlan; @@ -26,19 +26,18 @@ pub struct Aggregate { impl Aggregate { pub(crate) fn try_new( input: Arc, - aggregations: Vec, + aggregations: Vec, groupby: Vec, ) -> logical_plan::Result { - let output_schema = { - let upstream_schema = input.schema(); - let fields = groupby - .iter() - .map(|e| e.to_field(&upstream_schema)) - .chain(aggregations.iter().map(|ae| ae.to_field(&upstream_schema))) - .collect::>>() - .context(CreationSnafu)?; - Schema::new(fields).context(CreationSnafu)?.into() - }; + let upstream_schema = input.schema(); + let (groupby, groupby_fields) = + resolve_exprs(groupby, &upstream_schema).context(CreationSnafu)?; + let (aggregations, aggregation_fields) = + resolve_aggexprs(aggregations, &upstream_schema).context(CreationSnafu)?; + + let fields = [groupby_fields, aggregation_fields].concat(); + + let output_schema = Schema::new(fields).context(CreationSnafu)?.into(); Ok(Self { aggregations, diff --git a/src/daft-plan/src/logical_ops/explode.rs b/src/daft-plan/src/logical_ops/explode.rs index aac1a9632f..fc2b56f709 100644 --- a/src/daft-plan/src/logical_ops/explode.rs +++ b/src/daft-plan/src/logical_ops/explode.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use daft_core::schema::{Schema, SchemaRef}; -use daft_dsl::ExprRef; +use daft_dsl::{resolve_exprs, ExprRef}; use itertools::Itertools; use snafu::ResultExt; @@ -24,13 +24,16 @@ impl Explode { input: Arc, to_explode: Vec, ) -> logical_plan::Result { + let upstream_schema = input.schema(); + + let (to_explode, _) = 
resolve_exprs(to_explode, &upstream_schema).context(CreationSnafu)?; + let explode_exprs = to_explode .iter() .cloned() .map(daft_dsl::functions::list::explode) .collect::>(); let exploded_schema = { - let upstream_schema = input.schema(); let explode_schema = { let explode_fields = explode_exprs .iter() @@ -47,6 +50,7 @@ impl Explode { .collect::>(); Schema::new(fields).context(CreationSnafu)?.into() }; + Ok(Self { input, to_explode, diff --git a/src/daft-plan/src/logical_ops/filter.rs b/src/daft-plan/src/logical_ops/filter.rs index 4b6bff3e3f..91f2f250ba 100644 --- a/src/daft-plan/src/logical_ops/filter.rs +++ b/src/daft-plan/src/logical_ops/filter.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use daft_core::DataType; -use daft_dsl::ExprRef; +use daft_dsl::{resolve_expr, ExprRef}; use snafu::ResultExt; use crate::logical_plan::{CreationSnafu, Result}; @@ -18,9 +18,8 @@ pub struct Filter { impl Filter { pub(crate) fn try_new(input: Arc, predicate: ExprRef) -> Result { - let field = predicate - .to_field(input.schema().as_ref()) - .context(CreationSnafu)?; + let (predicate, field) = resolve_expr(predicate, &input.schema()).context(CreationSnafu)?; + if !matches!(field.dtype, DataType::Boolean) { return Err(DaftError::ValueError(format!( "Expected expression {predicate} to resolve to type Boolean, but received: {}", diff --git a/src/daft-plan/src/logical_ops/join.rs b/src/daft-plan/src/logical_ops/join.rs index 8fe2bad6fb..523f408c52 100644 --- a/src/daft-plan/src/logical_ops/join.rs +++ b/src/daft-plan/src/logical_ops/join.rs @@ -1,12 +1,12 @@ use std::{collections::HashSet, sync::Arc}; -use common_error::{DaftError, DaftResult}; +use common_error::DaftError; use daft_core::{ join::{JoinStrategy, JoinType}, schema::{hash_index_map, Schema, SchemaRef}, DataType, }; -use daft_dsl::ExprRef; +use daft_dsl::{resolve_exprs, ExprRef}; use itertools::Itertools; use snafu::ResultExt; @@ -54,12 +54,12 @@ impl Join { join_type: JoinType, join_strategy: Option, ) -> 
logical_plan::Result { - for (on_exprs, schema) in [(&left_on, left.schema()), (&right_on, right.schema())] { - let on_fields = on_exprs - .iter() - .map(|e| e.to_field(schema.as_ref())) - .collect::>>() - .context(CreationSnafu)?; + let (left_on, left_fields) = + resolve_exprs(left_on, &left.schema()).context(CreationSnafu)?; + let (right_on, right_fields) = + resolve_exprs(right_on, &right.schema()).context(CreationSnafu)?; + + for (on_exprs, on_fields) in [(&left_on, left_fields), (&right_on, right_fields)] { let on_schema = Schema::new(on_fields).context(CreationSnafu)?; for (field, expr) in on_schema.fields.values().zip(on_exprs.iter()) { if matches!(field.dtype, DataType::Null) { diff --git a/src/daft-plan/src/logical_ops/pivot.rs b/src/daft-plan/src/logical_ops/pivot.rs index 3ab95a22b5..ae9115dcdb 100644 --- a/src/daft-plan/src/logical_ops/pivot.rs +++ b/src/daft-plan/src/logical_ops/pivot.rs @@ -1,12 +1,11 @@ use std::sync::Arc; -use common_error::DaftResult; use daft_core::datatypes::Field; use itertools::Itertools; use snafu::ResultExt; use daft_core::schema::{Schema, SchemaRef}; -use daft_dsl::{AggExpr, ExprRef}; +use daft_dsl::{resolve_aggexpr, resolve_expr, resolve_exprs, AggExpr, ExprRef}; use crate::logical_plan::{self, CreationSnafu}; use crate::LogicalPlan; @@ -28,19 +27,20 @@ impl Pivot { group_by: Vec, pivot_column: ExprRef, value_column: ExprRef, - aggregation: AggExpr, + aggregation: ExprRef, names: Vec, ) -> logical_plan::Result { + let upstream_schema = input.schema(); + let (group_by, group_by_fields) = + resolve_exprs(group_by, &upstream_schema).context(CreationSnafu)?; + let (pivot_column, _) = + resolve_expr(pivot_column, &upstream_schema).context(CreationSnafu)?; + let (value_column, value_col_field) = + resolve_expr(value_column, &upstream_schema).context(CreationSnafu)?; + let (aggregation, _) = + resolve_aggexpr(aggregation, &upstream_schema).context(CreationSnafu)?; + let output_schema = { - let upstream_schema = input.schema(); - 
let group_by_fields = group_by - .iter() - .map(|e| e.to_field(&upstream_schema)) - .collect::>>() - .context(CreationSnafu)?; - let value_col_field = value_column - .to_field(&upstream_schema) - .context(CreationSnafu)?; let value_col_dtype = value_col_field.dtype; let pivot_value_fields = names .iter() diff --git a/src/daft-plan/src/logical_ops/project.rs b/src/daft-plan/src/logical_ops/project.rs index 86905a41a9..07881ccf1c 100644 --- a/src/daft-plan/src/logical_ops/project.rs +++ b/src/daft-plan/src/logical_ops/project.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use daft_core::datatypes::FieldID; use daft_core::schema::{Schema, SchemaRef}; -use daft_dsl::{optimization, AggExpr, ApproxPercentileParams, Expr, ExprRef}; +use daft_dsl::{optimization, resolve_exprs, AggExpr, ApproxPercentileParams, Expr, ExprRef}; use indexmap::{IndexMap, IndexSet}; use itertools::Itertools; use snafu::ResultExt; @@ -26,19 +26,15 @@ impl Project { projection: Vec, resource_request: ResourceRequest, ) -> Result { + let (projection, fields) = + resolve_exprs(projection, &input.schema()).context(CreationSnafu)?; + // Factor the projection and see if there are any substitutions to factor out. 
let (factored_input, factored_projection) = Self::try_factor_subexpressions(input, projection, &resource_request)?; - let upstream_schema = factored_input.schema(); - let projected_schema = { - let fields = factored_projection - .iter() - .map(|e| e.to_field(&upstream_schema)) - .collect::>>() - .context(CreationSnafu)?; - Schema::new(fields).context(CreationSnafu)?.into() - }; + let projected_schema = Schema::new(fields).context(CreationSnafu)?.into(); + Ok(Self { input: factored_input, projection: factored_projection, diff --git a/src/daft-plan/src/logical_ops/repartition.rs b/src/daft-plan/src/logical_ops/repartition.rs index f531f5082b..a537d7b476 100644 --- a/src/daft-plan/src/logical_ops/repartition.rs +++ b/src/daft-plan/src/logical_ops/repartition.rs @@ -1,8 +1,12 @@ use std::sync::Arc; use common_error::DaftResult; +use daft_dsl::resolve_exprs; -use crate::{partitioning::RepartitionSpec, LogicalPlan}; +use crate::{ + partitioning::{HashRepartitionConfig, RepartitionSpec}, + LogicalPlan, +}; #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct Repartition { @@ -16,6 +20,17 @@ impl Repartition { input: Arc, repartition_spec: RepartitionSpec, ) -> DaftResult { + let repartition_spec = match repartition_spec { + RepartitionSpec::Hash(HashRepartitionConfig { num_partitions, by }) => { + let (resolved_by, _) = resolve_exprs(by, &input.schema())?; + RepartitionSpec::Hash(HashRepartitionConfig { + num_partitions, + by: resolved_by, + }) + } + RepartitionSpec::Random(_) | RepartitionSpec::IntoPartitions(_) => repartition_spec, + }; + Ok(Self { input, repartition_spec, diff --git a/src/daft-plan/src/logical_ops/sink.rs b/src/daft-plan/src/logical_ops/sink.rs index b9e347bfa5..dcf4bde192 100644 --- a/src/daft-plan/src/logical_ops/sink.rs +++ b/src/daft-plan/src/logical_ops/sink.rs @@ -5,8 +5,9 @@ use daft_core::{ datatypes::Field, schema::{Schema, SchemaRef}, }; +use daft_dsl::resolve_exprs; -use crate::{sink_info::SinkInfo, LogicalPlan}; +use 
crate::{sink_info::SinkInfo, LogicalPlan, OutputFileInfo}; #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct Sink { @@ -21,6 +22,32 @@ impl Sink { pub(crate) fn try_new(input: Arc, sink_info: Arc) -> DaftResult { let schema = input.schema(); + let sink_info = match sink_info.as_ref() { + SinkInfo::OutputFileInfo(OutputFileInfo { + root_dir, + file_format, + partition_cols, + compression, + io_config, + }) => { + let resolved_partition_cols = partition_cols + .clone() + .map(|cols| { + resolve_exprs(cols, &schema).map(|(resolved_cols, _)| resolved_cols) + }) + .transpose()?; + + Arc::new(SinkInfo::OutputFileInfo(OutputFileInfo { + root_dir: root_dir.clone(), + file_format: file_format.clone(), + partition_cols: resolved_partition_cols, + compression: compression.clone(), + io_config: io_config.clone(), + })) + } + _ => sink_info, + }; + let fields = match sink_info.as_ref() { SinkInfo::OutputFileInfo(output_file_info) => { let mut fields = vec![Field::new("path", daft_core::DataType::Utf8)]; diff --git a/src/daft-plan/src/logical_ops/sort.rs b/src/daft-plan/src/logical_ops/sort.rs index 16196ef0a8..adc25022f7 100644 --- a/src/daft-plan/src/logical_ops/sort.rs +++ b/src/daft-plan/src/logical_ops/sort.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use common_error::DaftError; use daft_core::schema::Schema; use daft_core::DataType; -use daft_dsl::ExprRef; +use daft_dsl::{resolve_exprs, ExprRef}; use itertools::Itertools; use snafu::ResultExt; @@ -31,15 +31,12 @@ impl Sort { )) .context(CreationSnafu); } - let upstream_schema = input.schema(); - let sort_by_resolved_schema = { - let sort_by_fields = sort_by - .iter() - .map(|e| e.to_field(&upstream_schema)) - .collect::>>() - .context(CreationSnafu)?; - Schema::new(sort_by_fields).context(CreationSnafu)? 
- }; + + let (sort_by, sort_by_fields) = + resolve_exprs(sort_by, &input.schema()).context(CreationSnafu)?; + + let sort_by_resolved_schema = Schema::new(sort_by_fields).context(CreationSnafu)?; + for (field, expr) in sort_by_resolved_schema.fields.values().zip(sort_by.iter()) { // Disallow sorting by null, binary, and boolean columns. // TODO(Clark): This is a port of an existing constraint, we should look at relaxing this. diff --git a/src/daft-plan/src/logical_ops/unpivot.rs b/src/daft-plan/src/logical_ops/unpivot.rs index 06a6ad168a..1815388ee1 100644 --- a/src/daft-plan/src/logical_ops/unpivot.rs +++ b/src/daft-plan/src/logical_ops/unpivot.rs @@ -1,13 +1,13 @@ use std::sync::Arc; -use common_error::{DaftError, DaftResult}; +use common_error::DaftError; use daft_core::{ datatypes::Field, schema::{Schema, SchemaRef}, utils::supertype::try_get_supertype, DataType, }; -use daft_dsl::ExprRef; +use daft_dsl::{resolve_exprs, ExprRef}; use itertools::Itertools; use snafu::ResultExt; @@ -43,11 +43,8 @@ impl Unpivot { } let input_schema = input.schema(); - let values_fields = values - .iter() - .map(|e| e.to_field(&input_schema)) - .collect::>>() - .context(CreationSnafu)?; + let (values, values_fields) = + resolve_exprs(values, &input_schema).context(CreationSnafu)?; let value_dtype = values_fields .iter() @@ -59,12 +56,12 @@ impl Unpivot { let variable_field = Field::new(variable_name, DataType::Utf8); let value_field = Field::new(value_name, value_dtype); - let output_fields = ids - .iter() - .map(|e| e.to_field(&input_schema)) - .chain(vec![Ok(variable_field), Ok(value_field)]) - .collect::>>() - .context(CreationSnafu)?; + let (ids, ids_fields) = resolve_exprs(ids, &input_schema).context(CreationSnafu)?; + + let output_fields = ids_fields + .into_iter() + .chain([variable_field, value_field]) + .collect::>(); let output_schema = Schema::new(output_fields).context(CreationSnafu)?.into(); diff --git 
a/src/daft-plan/src/logical_optimization/rules/push_down_projection.rs b/src/daft-plan/src/logical_optimization/rules/push_down_projection.rs index db91617d15..11285e69b5 100644 --- a/src/daft-plan/src/logical_optimization/rules/push_down_projection.rs +++ b/src/daft-plan/src/logical_optimization/rules/push_down_projection.rs @@ -214,7 +214,7 @@ impl PushDownProjection { .aggregations .iter() .filter(|&e| required_columns.contains(e.name())) - .cloned() + .map(|ae| ae.into()) .collect::>(); if pruned_aggregate_exprs.len() < aggregate.aggregations.len() { diff --git a/src/daft-plan/src/logical_plan.rs b/src/daft-plan/src/logical_plan.rs index 3e8969958b..cb04907447 100644 --- a/src/daft-plan/src/logical_plan.rs +++ b/src/daft-plan/src/logical_plan.rs @@ -194,8 +194,8 @@ impl LogicalPlan { Self::Sort(Sort { sort_by, descending, .. }) => Self::Sort(Sort::try_new(input.clone(), sort_by.clone(), descending.clone()).unwrap()), Self::Repartition(Repartition { repartition_spec: scheme_config, .. 
}) => Self::Repartition(Repartition::try_new(input.clone(), scheme_config.clone()).unwrap()), Self::Distinct(_) => Self::Distinct(Distinct::new(input.clone())), - Self::Aggregate(Aggregate { aggregations, groupby, ..}) => Self::Aggregate(Aggregate::try_new(input.clone(), aggregations.clone(), groupby.clone()).unwrap()), - Self::Pivot(Pivot { group_by, pivot_column, value_column, aggregation,names,..}) => Self::Pivot(Pivot::try_new(input.clone(), group_by.clone(), pivot_column.clone(), value_column.clone(), aggregation.clone(), names.clone()).unwrap()), + Self::Aggregate(Aggregate { aggregations, groupby, ..}) => Self::Aggregate(Aggregate::try_new(input.clone(), aggregations.iter().map(|ae| ae.into()).collect(), groupby.clone()).unwrap()), + Self::Pivot(Pivot { group_by, pivot_column, value_column, aggregation, names, ..}) => Self::Pivot(Pivot::try_new(input.clone(), group_by.clone(), pivot_column.clone(), value_column.clone(), aggregation.into(), names.clone()).unwrap()), Self::Sink(Sink { sink_info, .. }) => Self::Sink(Sink::try_new(input.clone(), sink_info.clone()).unwrap()), Self::MonotonicallyIncreasingId(MonotonicallyIncreasingId {column_name, .. 
}) => Self::MonotonicallyIncreasingId(MonotonicallyIncreasingId::new(input.clone(), Some(column_name))), Self::Unpivot(Unpivot {ids, values, variable_name, value_name, output_schema, ..}) => Self::Unpivot(Unpivot { input: input.clone(), ids: ids.clone(), values: values.clone(), variable_name: variable_name.clone(), value_name: value_name.clone(), output_schema: output_schema.clone() }), diff --git a/tests/dataframe/test_getter_sugar.py b/tests/dataframe/test_getter_sugar.py new file mode 100644 index 0000000000..3d6abc1a64 --- /dev/null +++ b/tests/dataframe/test_getter_sugar.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +import daft + + +def test_getter_sugar(): + df = daft.from_pydict({"a": [{"x": 1, "y": "one"}, {"x": 2, "y": "two"}]}) + + df = df.select("a.x", "a.y") + + assert df.to_pydict() == {"x": [1, 2], "y": ["one", "two"]} + + +def test_getter_sugar_nested(): + df = daft.from_pydict({"a": [{"b": {"c": 1}}, {"b": {"c": 2}}]}) + + df = df.select("a.b", "a.b.c") + + assert df.to_pydict() == {"b": [{"c": 1}, {"c": 2}], "c": [1, 2]} + + +def test_getter_sugar_nested_multiple(): + df = daft.from_pydict({"a.b": [1, 2, 3], "a": [{"b": 1}, {"b": 2}, {"b": 3}]}) + + df = df.select("a", "a.b") + + assert df.to_pydict() == {"a": [{"b": 1}, {"b": 2}, {"b": 3}], "a.b": [1, 2, 3]}