diff --git a/Cargo.lock b/Cargo.lock
index d83b23aa6c..8bf5e58b39 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2082,6 +2082,7 @@ dependencies = [
 name = "daft-sql"
 version = "0.3.0-dev0"
 dependencies = [
+ "common-daft-config",
 "common-error",
 "daft-core",
 "daft-dsl",
diff --git a/daft/daft.pyi b/daft/daft.pyi
index 832ff3fdea..78d30c8d26 100644
--- a/daft/daft.pyi
+++ b/daft/daft.pyi
@@ -1242,7 +1242,7 @@ def minhash(
     ngram_size: int,
     seed: int = 1,
 ) -> PyExpr: ...
-def sql(sql: str, catalog: PyCatalog) -> LogicalPlanBuilder: ...
+def sql(sql: str, catalog: PyCatalog, daft_planning_config: PyDaftPlanningConfig) -> LogicalPlanBuilder: ...
 def sql_expr(sql: str) -> PyExpr: ...
 def utf8_count_matches(expr: PyExpr, patterns: PyExpr, whole_words: bool, case_sensitive: bool) -> PyExpr: ...
 def list_sort(expr: PyExpr, desc: PyExpr) -> PyExpr: ...
@@ -1625,6 +1625,7 @@ class LogicalPlanBuilder:
     ) -> LogicalPlanBuilder: ...
     @staticmethod
     def table_scan(scan_operator: ScanOperatorHandle) -> LogicalPlanBuilder: ...
+    def with_planning_config(self, daft_planning_config: PyDaftPlanningConfig) -> LogicalPlanBuilder: ...
     def select(self, to_select: list[PyExpr]) -> LogicalPlanBuilder: ...
     def with_columns(self, columns: list[PyExpr]) -> LogicalPlanBuilder: ...
     def exclude(self, to_exclude: list[str]) -> LogicalPlanBuilder: ...
diff --git a/daft/logical/builder.py b/daft/logical/builder.py
index 89f19458fd..b2717df2f6 100644
--- a/daft/logical/builder.py
+++ b/daft/logical/builder.py
@@ -1,8 +1,10 @@
 from __future__ import annotations
 
+import functools
 import pathlib
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Callable
 
+from daft.context import get_context
 from daft.daft import (
     CountMode,
     FileFormat,
@@ -26,6 +28,25 @@
 )
 
 
+def _apply_daft_planning_config_to_initializer(classmethod_func: Callable[..., LogicalPlanBuilder]):
+    """Decorator to be applied to any @classmethod instantiation method on LogicalPlanBuilder
+
+    This decorator ensures that the current DaftPlanningConfig is applied to the instantiated LogicalPlanBuilder
+    """
+
+    @functools.wraps(classmethod_func)
+    def wrapper(cls: type[LogicalPlanBuilder], *args, **kwargs):
+        instantiated_logical_plan_builder = classmethod_func(cls, *args, **kwargs)
+
+        # Parametrize the builder with the current DaftPlanningConfig
+        inner = instantiated_logical_plan_builder._builder
+        inner = inner.with_planning_config(get_context().daft_planning_config)
+
+        return cls(inner)
+
+    return wrapper
+
+
 class LogicalPlanBuilder:
     """
     A logical plan builder for the Daft DataFrame.
@@ -91,6 +112,7 @@ def optimize(self) -> LogicalPlanBuilder:
         return LogicalPlanBuilder(builder)
 
     @classmethod
+    @_apply_daft_planning_config_to_initializer
     def from_in_memory_scan(
         cls,
         partition: PartitionCacheEntry,
@@ -110,6 +132,7 @@ def from_in_memory_scan(
         return cls(builder)
 
     @classmethod
+    @_apply_daft_planning_config_to_initializer
     def from_tabular_scan(
         cls,
         *,
diff --git a/daft/sql/sql.py b/daft/sql/sql.py
index d1b85af976..a5334a28b9 100644
--- a/daft/sql/sql.py
+++ b/daft/sql/sql.py
@@ -1,6 +1,7 @@
 # isort: dont-add-import: from __future__ import annotations
 
 from daft.api_annotations import PublicAPI
+from daft.context import get_context
 from daft.daft import PyCatalog as _PyCatalog
 from daft.daft import sql as _sql
 from daft.daft import sql_expr as _sql_expr
@@ -45,6 +46,8 @@ def sql(sql: str, catalog: SQLCatalog) -> DataFrame:
     Returns:
         DataFrame: Dataframe containing the results of the query
     """
+    planning_config = get_context().daft_planning_config
+
     _py_catalog = catalog._catalog
-    _py_logical = _sql(sql, _py_catalog)
+    _py_logical = _sql(sql, _py_catalog, planning_config)
     return DataFrame(LogicalPlanBuilder(_py_logical))
diff --git a/src/common/daft-config/src/lib.rs b/src/common/daft-config/src/lib.rs
index 10d2b75cff..c423d3aaaa 100644
--- a/src/common/daft-config/src/lib.rs
+++ b/src/common/daft-config/src/lib.rs
@@ -10,7 +10,7 @@ pub const BOLD_TABLE_HEADERS_IN_DISPLAY: &str = "DAFT_BOLD_TABLE_HEADERS";
 /// 1. Creation of a Dataframe including any file listing and schema inference that needs to happen. Note
 ///    that this does not include the actual scan, which is taken care of by the DaftExecutionConfig.
 /// 2. Building of logical plan nodes
-#[derive(Clone, Serialize, Deserialize, Default)]
+#[derive(Clone, Serialize, Deserialize, Default, Debug)]
 pub struct DaftPlanningConfig {
     pub default_io_config: IOConfig,
     pub enable_actor_pool_projections: bool,
@@ -111,6 +111,9 @@ mod python;
 
 #[cfg(feature = "python")]
 pub use python::PyDaftExecutionConfig;
 
+#[cfg(feature = "python")]
+pub use python::PyDaftPlanningConfig;
+
 #[cfg(feature = "python")]
 use pyo3::prelude::*;
diff --git a/src/daft-plan/src/builder.rs b/src/daft-plan/src/builder.rs
index cd719e7a50..9a5719c652 100644
--- a/src/daft-plan/src/builder.rs
+++ b/src/daft-plan/src/builder.rs
@@ -5,7 +5,7 @@ use std::{
 
 use crate::{
     logical_ops,
-    logical_optimization::Optimizer,
+    logical_optimization::{Optimizer, OptimizerConfig},
     logical_plan::LogicalPlan,
     partitioning::{
         HashRepartitionConfig, IntoPartitionsConfig, RandomShuffleConfig, RepartitionSpec,
@@ -14,6 +14,7 @@ use crate::{
     source_info::SourceInfo,
     LogicalPlanRef,
 };
+use common_daft_config::DaftPlanningConfig;
 use common_display::mermaid::MermaidDisplayOptions;
 use common_error::DaftResult;
 use common_io_config::IOConfig;
@@ -29,6 +30,7 @@ use daft_scan::{PhysicalScanInfo, Pushdowns, ScanOperatorRef};
 use {
     crate::sink_info::{CatalogInfo, IcebergCatalogInfo},
     crate::source_info::InMemoryInfo,
+    common_daft_config::PyDaftPlanningConfig,
     daft_core::python::schema::PySchema,
     daft_dsl::python::PyExpr,
     daft_scan::python::pylib::ScanOperatorHandle,
@@ -44,36 +46,47 @@ use {
 pub struct LogicalPlanBuilder {
     // The current root of the logical plan in this builder.
     pub plan: Arc<LogicalPlan>,
+    config: Option<Arc<DaftPlanningConfig>>,
 }
 
 impl LogicalPlanBuilder {
-    pub fn new(plan: Arc<LogicalPlan>) -> Self {
-        Self { plan }
+    pub fn new(plan: Arc<LogicalPlan>, config: Option<Arc<DaftPlanningConfig>>) -> Self {
+        Self { plan, config }
     }
 }
 
-impl From<LogicalPlan> for LogicalPlanBuilder {
-    fn from(plan: LogicalPlan) -> Self {
+impl From<&LogicalPlanBuilder> for LogicalPlanBuilder {
+    fn from(builder: &LogicalPlanBuilder) -> Self {
         Self {
-            plan: Arc::new(plan),
+            plan: builder.plan.clone(),
+            config: builder.config.clone(),
         }
     }
 }
 
-impl From<LogicalPlanRef> for LogicalPlanBuilder {
-    fn from(plan: LogicalPlanRef) -> Self {
-        Self { plan: plan.clone() }
+impl From<LogicalPlanBuilder> for LogicalPlanRef {
+    fn from(value: LogicalPlanBuilder) -> Self {
+        value.plan
     }
 }
 
-impl From<&LogicalPlanBuilder> for LogicalPlanBuilder {
-    fn from(builder: &LogicalPlanBuilder) -> Self {
-        Self {
-            plan: builder.plan.clone(),
-        }
+
+impl From<&LogicalPlanBuilder> for LogicalPlanRef {
+    fn from(value: &LogicalPlanBuilder) -> Self {
+        value.plan.clone()
     }
 }
 
 impl LogicalPlanBuilder {
+    /// Replace the LogicalPlanBuilder's plan with the provided plan
+    pub fn with_new_plan<LP: Into<Arc<LogicalPlan>>>(&self, plan: LP) -> Self {
+        Self::new(plan.into(), self.config.clone())
+    }
+
+    /// Parametrize the LogicalPlanBuilder with a DaftPlanningConfig
+    pub fn with_config(&self, config: Arc<DaftPlanningConfig>) -> Self {
+        Self::new(self.plan.clone(), Some(config))
+    }
+
     #[cfg(feature = "python")]
     pub fn in_memory_scan(
         partition_key: &str,
@@ -94,7 +107,7 @@ impl LogicalPlanBuilder {
         ));
         let logical_plan: LogicalPlan =
             logical_ops::Source::new(schema.clone(), source_info.into()).into();
-        Ok(logical_plan.into())
+        Ok(LogicalPlanBuilder::new(logical_plan.into(), None))
     }
 
     pub fn table_scan(
@@ -128,13 +141,13 @@
         };
         let logical_plan: LogicalPlan =
             logical_ops::Source::new(output_schema, source_info.into()).into();
-        Ok(logical_plan.into())
+        Ok(LogicalPlanBuilder::new(logical_plan.into(), None))
     }
 
     pub fn select(&self, to_select: Vec<ExprRef>) -> DaftResult<Self> {
         let logical_plan: LogicalPlan =
             logical_ops::Project::try_new(self.plan.clone(), to_select)?.into();
-        Ok(logical_plan.into())
+        Ok(self.with_new_plan(logical_plan))
     }
 
     pub fn with_columns(&self, columns: Vec<ExprRef>) -> DaftResult<Self> {
@@ -167,7 +180,7 @@
 
         let logical_plan: LogicalPlan =
             logical_ops::Project::try_new(self.plan.clone(), exprs)?.into();
-        Ok(logical_plan.into())
+        Ok(self.with_new_plan(logical_plan))
     }
 
     pub fn exclude(&self, to_exclude: Vec<String>) -> DaftResult<Self> {
@@ -188,25 +201,25 @@
 
         let logical_plan: LogicalPlan =
             logical_ops::Project::try_new(self.plan.clone(), exprs)?.into();
-        Ok(logical_plan.into())
+        Ok(self.with_new_plan(logical_plan))
     }
 
     pub fn filter(&self, predicate: ExprRef) -> DaftResult<Self> {
         let logical_plan: LogicalPlan =
             logical_ops::Filter::try_new(self.plan.clone(), predicate)?.into();
-        Ok(logical_plan.into())
+        Ok(self.with_new_plan(logical_plan))
     }
 
     pub fn limit(&self, limit: i64, eager: bool) -> DaftResult<Self> {
         let logical_plan: LogicalPlan =
             logical_ops::Limit::new(self.plan.clone(), limit, eager).into();
-        Ok(logical_plan.into())
+        Ok(self.with_new_plan(logical_plan))
     }
 
     pub fn explode(&self, to_explode: Vec<ExprRef>) -> DaftResult<Self> {
         let logical_plan: LogicalPlan =
             logical_ops::Explode::try_new(self.plan.clone(), to_explode)?.into();
-        Ok(logical_plan.into())
+        Ok(self.with_new_plan(logical_plan))
     }
 
     pub fn unpivot(
@@ -244,13 +257,13 @@ impl LogicalPlanBuilder {
             value_name,
         )?
         .into();
-        Ok(logical_plan.into())
+        Ok(self.with_new_plan(logical_plan))
     }
 
     pub fn sort(&self, sort_by: Vec<ExprRef>, descending: Vec<bool>) -> DaftResult<Self> {
         let logical_plan: LogicalPlan =
             logical_ops::Sort::try_new(self.plan.clone(), sort_by, descending)?.into();
-        Ok(logical_plan.into())
+        Ok(self.with_new_plan(logical_plan))
     }
 
     pub fn hash_repartition(
@@ -263,7 +276,7 @@
             RepartitionSpec::Hash(HashRepartitionConfig::new(num_partitions, partition_by)),
         )?
         .into();
-        Ok(logical_plan.into())
+        Ok(self.with_new_plan(logical_plan))
     }
 
     pub fn random_shuffle(&self, num_partitions: Option<usize>) -> DaftResult<Self> {
@@ -272,7 +285,7 @@
             RepartitionSpec::Random(RandomShuffleConfig::new(num_partitions)),
         )?
         .into();
-        Ok(logical_plan.into())
+        Ok(self.with_new_plan(logical_plan))
     }
 
     pub fn into_partitions(&self, num_partitions: usize) -> DaftResult<Self> {
@@ -281,12 +294,12 @@
             RepartitionSpec::IntoPartitions(IntoPartitionsConfig::new(num_partitions)),
         )?
         .into();
-        Ok(logical_plan.into())
+        Ok(self.with_new_plan(logical_plan))
     }
 
     pub fn distinct(&self) -> DaftResult<Self> {
         let logical_plan: LogicalPlan = logical_ops::Distinct::new(self.plan.clone()).into();
-        Ok(logical_plan.into())
+        Ok(self.with_new_plan(logical_plan))
     }
 
     pub fn sample(
@@ -297,7 +310,7 @@
     ) -> DaftResult<Self> {
         let logical_plan: LogicalPlan =
             logical_ops::Sample::new(self.plan.clone(), fraction, with_replacement, seed).into();
-        Ok(logical_plan.into())
+        Ok(self.with_new_plan(logical_plan))
     }
 
     pub fn aggregate(
@@ -307,7 +320,7 @@
     ) -> DaftResult<Self> {
         let logical_plan: LogicalPlan =
             logical_ops::Aggregate::try_new(self.plan.clone(), agg_exprs, groupby_exprs)?.into();
-        Ok(logical_plan.into())
+        Ok(self.with_new_plan(logical_plan))
     }
 
     pub fn pivot(
@@ -327,10 +340,10 @@
             names,
         )?
         .into();
-        Ok(pivot_logical_plan.into())
+        Ok(self.with_new_plan(pivot_logical_plan))
     }
 
-    pub fn join<Right: Into<LogicalPlanBuilder>>(
+    pub fn join<Right: Into<LogicalPlanRef>>(
         &self,
         right: Right,
         left_on: Vec<ExprRef>,
@@ -340,26 +353,26 @@
     ) -> DaftResult<Self> {
         let logical_plan: LogicalPlan = logical_ops::Join::try_new(
             self.plan.clone(),
-            right.into().plan.clone(),
+            right.into().clone(),
             left_on,
             right_on,
             join_type,
             join_strategy,
         )?
         .into();
-        Ok(logical_plan.into())
+        Ok(self.with_new_plan(logical_plan))
     }
 
     pub fn concat(&self, other: &Self) -> DaftResult<Self> {
         let logical_plan: LogicalPlan =
             logical_ops::Concat::try_new(self.plan.clone(), other.plan.clone())?.into();
-        Ok(logical_plan.into())
+        Ok(self.with_new_plan(logical_plan))
     }
 
     pub fn add_monotonically_increasing_id(&self, column_name: Option<&str>) -> DaftResult<Self> {
         let logical_plan: LogicalPlan =
             logical_ops::MonotonicallyIncreasingId::new(self.plan.clone(), column_name).into();
-        Ok(logical_plan.into())
+        Ok(self.with_new_plan(logical_plan))
     }
 
     pub fn table_write(
@@ -380,7 +393,7 @@
 
         let logical_plan: LogicalPlan =
             logical_ops::Sink::try_new(self.plan.clone(), sink_info.into())?.into();
-        Ok(logical_plan.into())
+        Ok(self.with_new_plan(logical_plan))
     }
 
     #[cfg(feature = "python")]
@@ -409,7 +422,7 @@
 
         let logical_plan: LogicalPlan =
             logical_ops::Sink::try_new(self.plan.clone(), sink_info.into())?.into();
-        Ok(logical_plan.into())
+        Ok(self.with_new_plan(logical_plan))
     }
 
     #[cfg(feature = "python")]
@@ -437,7 +450,7 @@
 
         let logical_plan: LogicalPlan =
             logical_ops::Sink::try_new(self.plan.clone(), sink_info.into())?.into();
-        Ok(logical_plan.into())
+        Ok(self.with_new_plan(logical_plan))
    }
 
     #[cfg(feature = "python")]
@@ -464,7 +477,7 @@
 
         let logical_plan: LogicalPlan =
             logical_ops::Sink::try_new(self.plan.clone(), sink_info.into())?.into();
-        Ok(logical_plan.into())
+        Ok(self.with_new_plan(logical_plan))
     }
 
     pub fn build(&self) -> Arc<LogicalPlan> {
@@ -536,6 +549,13 @@ impl PyLogicalPlanBuilder {
         Ok(LogicalPlanBuilder::table_scan(scan_operator.into(), None)?.into())
     }
 
+    pub fn with_planning_config(
+        &self,
+        daft_planning_config: PyDaftPlanningConfig,
+    ) -> PyResult<Self> {
+        Ok(self.builder.with_config(daft_planning_config.config).into())
+    }
+
     pub fn select(&self, to_select: Vec<PyExpr>) -> PyResult<Self> {
         Ok(self.builder.select(pyexprs_to_exprs(to_select))?.into())
     }
@@ -777,7 +797,12 @@ impl PyLogicalPlanBuilder {
     /// Optimize the underlying logical plan, returning a new plan builder containing the optimized plan.
     pub fn optimize(&self, py: Python) -> PyResult<Self> {
         py.allow_threads(|| {
-            let optimizer = Optimizer::new(Default::default());
+            // Create optimizer
+            let default_optimizer_config: OptimizerConfig = Default::default();
+            let optimizer_config = OptimizerConfig { enable_actor_pool_projections: self.builder.config.as_ref().map(|planning_cfg| planning_cfg.enable_actor_pool_projections).unwrap_or(default_optimizer_config.enable_actor_pool_projections), ..default_optimizer_config };
+            let optimizer = Optimizer::new(optimizer_config);
+
+            // Run LogicalPlan optimizations
             let unoptimized_plan = self.builder.build();
             let optimized_plan = optimizer.optimize(
                 unoptimized_plan,
@@ -800,7 +825,8 @@ impl PyLogicalPlanBuilder {
                     }
                 },
             )?;
-            let builder = LogicalPlanBuilder::new(optimized_plan);
+
+            let builder = LogicalPlanBuilder::new(optimized_plan, self.builder.config.clone());
             Ok(builder.into())
         })
     }
diff --git a/src/daft-plan/src/display.rs b/src/daft-plan/src/display.rs
index 5af4c9034b..543a728892 100644
--- a/src/daft-plan/src/display.rs
+++ b/src/daft-plan/src/display.rs
@@ -146,11 +146,11 @@ mod test {
     #[test]
     // create a random, complex plan and check if it can be displayed as expected
     fn test_mermaid_display() -> DaftResult<()> {
-        let subplan = LogicalPlanBuilder::new(plan_1())
+        let subplan = LogicalPlanBuilder::new(plan_1(), None)
             .filter(col("id").eq(lit(1)))?
             .build();
 
-        let subplan2 = LogicalPlanBuilder::new(plan_2())
+        let subplan2 = LogicalPlanBuilder::new(plan_2(), None)
             .filter(
                 startswith(col("last_name"), lit("S")).and(endswith(col("last_name"), lit("n"))),
             )?
@@ -160,7 +160,7 @@
             .sort(vec![col("last_name")], vec![false])?
             .build();
 
-        let plan = LogicalPlanBuilder::new(subplan)
+        let plan = LogicalPlanBuilder::new(subplan, None)
             .join(
                 subplan2,
                 vec![col("id")],
@@ -217,11 +217,11 @@ Project1 --> Limit0
     #[test]
     // create a random, complex plan and check if it can be displayed as expected
     fn test_mermaid_display_simple() -> DaftResult<()> {
-        let subplan = LogicalPlanBuilder::new(plan_1())
+        let subplan = LogicalPlanBuilder::new(plan_1(), None)
             .filter(col("id").eq(lit(1)))?
             .build();
 
-        let subplan2 = LogicalPlanBuilder::new(plan_2())
+        let subplan2 = LogicalPlanBuilder::new(plan_2(), None)
             .filter(
                 startswith(col("last_name"), lit("S")).and(endswith(col("last_name"), lit("n"))),
             )?
@@ -231,7 +231,7 @@ Project1 --> Limit0
             .sort(vec![col("last_name")], vec![false])?
             .build();
 
-        let plan = LogicalPlanBuilder::new(subplan)
+        let plan = LogicalPlanBuilder::new(subplan, None)
             .join(
                 subplan2,
                 vec![col("id")],
diff --git a/src/daft-plan/src/logical_optimization/mod.rs b/src/daft-plan/src/logical_optimization/mod.rs
index 82f349e55f..d270947f19 100644
--- a/src/daft-plan/src/logical_optimization/mod.rs
+++ b/src/daft-plan/src/logical_optimization/mod.rs
@@ -4,5 +4,5 @@ mod rules;
 #[cfg(test)]
 mod test;
 
-pub use optimizer::Optimizer;
+pub use optimizer::{Optimizer, OptimizerConfig};
 pub use rules::Transformed;
diff --git a/src/daft-plan/src/logical_optimization/rules/push_down_projection.rs b/src/daft-plan/src/logical_optimization/rules/push_down_projection.rs
index db88c468c1..916f4a1f25 100644
--- a/src/daft-plan/src/logical_optimization/rules/push_down_projection.rs
+++ b/src/daft-plan/src/logical_optimization/rules/push_down_projection.rs
@@ -909,6 +909,7 @@ mod tests {
     }
 
     /// Projection<-ActorPoolProject prunes columns from the ActorPoolProject
+    #[cfg(not(feature = "python"))]
     #[test]
     fn test_projection_pushdown_into_actorpoolproject() -> DaftResult<()> {
         use crate::logical_ops::ActorPoolProject;
@@ -1042,6 +1043,7 @@ mod tests {
     }
 
     /// Projection<-ActorPoolProject prunes ActorPoolProject entirely if the stateful projection column is pruned
+    #[cfg(not(feature = "python"))]
     #[test]
     fn test_projection_pushdown_into_actorpoolproject_completely_removed() -> DaftResult<()> {
         use crate::logical_ops::ActorPoolProject;
diff --git a/src/daft-plan/src/logical_optimization/rules/split_actor_pool_projects.rs b/src/daft-plan/src/logical_optimization/rules/split_actor_pool_projects.rs
index 3b86f39f92..3dd9568f56 100644
--- a/src/daft-plan/src/logical_optimization/rules/split_actor_pool_projects.rs
+++ b/src/daft-plan/src/logical_optimization/rules/split_actor_pool_projects.rs
@@ -575,6 +575,7 @@ mod tests {
         )
     }
 
+    #[cfg(not(feature = "python"))]
     fn create_stateful_udf(inputs: Vec<ExprRef>) -> ExprRef {
         Expr::Function {
             func: FunctionExpr::Python(PythonUDF::Stateful(StatefulPythonUDF {
@@ -598,6 +599,7 @@ mod tests {
         }
     }
 
+    #[cfg(not(feature = "python"))]
     #[test]
     fn test_with_column_stateful_udf_happypath() -> DaftResult<()> {
         let scan_op = dummy_scan_operator(vec![Field::new("a", daft_core::DataType::Utf8)]);
@@ -624,6 +626,7 @@ mod tests {
         Ok(())
     }
 
+    #[cfg(not(feature = "python"))]
     #[test]
     fn test_multiple_with_column_parallel() -> DaftResult<()> {
         let scan_op = dummy_scan_operator(vec![
@@ -712,6 +715,7 @@ mod tests {
         Ok(())
     }
 
+    #[cfg(not(feature = "python"))]
     #[test]
     fn test_multiple_with_column_parallel_common_subtree_eliminated() -> DaftResult<()> {
         let scan_op = dummy_scan_operator(vec![Field::new("a", daft_core::DataType::Utf8)]);
@@ -779,6 +783,7 @@ mod tests {
         Ok(())
     }
 
+    #[cfg(not(feature = "python"))]
     #[test]
     fn test_multiple_with_column_serial() -> DaftResult<()> {
         let scan_op = dummy_scan_operator(vec![Field::new("a", daft_core::DataType::Utf8)]);
@@ -855,6 +860,7 @@ mod tests {
         Ok(())
     }
 
+    #[cfg(not(feature = "python"))]
     #[test]
     fn test_multiple_with_column_serial_multiarg() -> DaftResult<()> {
         let scan_op = dummy_scan_operator(vec![
@@ -958,6 +964,7 @@ mod tests {
         Ok(())
     }
 
+    #[cfg(not(feature = "python"))]
     #[test]
     fn test_multiple_with_column_serial_multiarg_with_intermediate_stateless() -> DaftResult<()> {
         let scan_op = dummy_scan_operator(vec![
@@ -1078,6 +1085,7 @@ mod tests {
         Ok(())
     }
 
+    #[cfg(not(feature = "python"))]
     #[test]
     fn test_nested_with_column_same_names() -> DaftResult<()> {
         let scan_op = dummy_scan_operator(vec![Field::new("a", daft_core::DataType::Int64)]);
diff --git a/src/daft-plan/src/physical_planner/translate.rs b/src/daft-plan/src/physical_planner/translate.rs
index 8ef727eedf..b727e4763d 100644
--- a/src/daft-plan/src/physical_planner/translate.rs
+++ b/src/daft-plan/src/physical_planner/translate.rs
@@ -1088,7 +1088,7 @@ mod tests {
         let logical_plan = force_repartition(logical_plan, left_partitions)?
             .select(vec![col("a"), col("b"), col("c").alias("dataL")])?
             .join(
-                &join_node,
+                join_node,
                 vec![col("a"), col("b")],
                 vec![col("a"), col("b")],
                 daft_core::JoinType::Inner,
diff --git a/src/daft-sql/Cargo.toml b/src/daft-sql/Cargo.toml
index 75810e41a7..0191c3aa28 100644
--- a/src/daft-sql/Cargo.toml
+++ b/src/daft-sql/Cargo.toml
@@ -1,4 +1,5 @@
 [dependencies]
+common-daft-config = {path = "../common/daft-config"}
 common-error = {path = "../common/error"}
 daft-core = {path = "../daft-core"}
 daft-dsl = {path = "../daft-dsl"}
diff --git a/src/daft-sql/src/lib.rs b/src/daft-sql/src/lib.rs
index 2c5214adba..dbbbcc021b 100644
--- a/src/daft-sql/src/lib.rs
+++ b/src/daft-sql/src/lib.rs
@@ -167,7 +167,7 @@
         let sql = "select test as a from tbl1";
         let plan = planner.plan_sql(sql).unwrap();
 
-        let expected = LogicalPlanBuilder::new(tbl_1)
+        let expected = LogicalPlanBuilder::new(tbl_1, None)
             .select(vec![col("test").alias("a")])
             .unwrap()
             .build();
@@ -179,7 +179,7 @@
         let sql = "select test as a from tbl1 where test = 'a'";
         let plan = planner.plan_sql(sql)?;
 
-        let expected = LogicalPlanBuilder::new(tbl_1)
+        let expected = LogicalPlanBuilder::new(tbl_1, None)
             .filter(col("test").eq(lit("a")))?
             .select(vec![col("test").alias("a")])?
             .build();
@@ -192,7 +192,7 @@
         let sql = "select test as a from tbl1 limit 10";
         let plan = planner.plan_sql(sql)?;
 
-        let expected = LogicalPlanBuilder::new(tbl_1)
+        let expected = LogicalPlanBuilder::new(tbl_1, None)
             .select(vec![col("test").alias("a")])?
             .limit(10, true)?
             .build();
@@ -206,7 +206,7 @@
         let sql = "select utf8 from tbl1 order by utf8 desc";
         let plan = planner.plan_sql(sql)?;
 
-        let expected = LogicalPlanBuilder::new(tbl_1)
+        let expected = LogicalPlanBuilder::new(tbl_1, None)
             .select(vec![col("utf8")])?
             .sort(vec![col("utf8")], vec![true])?
             .build();
@@ -217,7 +217,7 @@
 
     #[rstest]
     fn test_cast(mut planner: SQLPlanner, tbl_1: LogicalPlanRef) -> SQLPlannerResult<()> {
-        let builder = LogicalPlanBuilder::new(tbl_1);
+        let builder = LogicalPlanBuilder::new(tbl_1, None);
         let cases = vec![
             (
                 "select bool::text from tbl1",
@@ -255,7 +255,7 @@
     ) -> SQLPlannerResult<()> {
         let sql = "select * from tbl2 join tbl3 on tbl2.id = tbl3.id";
         let plan = planner.plan_sql(sql)?;
-        let expected = LogicalPlanBuilder::new(tbl_2)
+        let expected = LogicalPlanBuilder::new(tbl_2, None)
             .join(
                 tbl_3,
                 vec![col("id")],
diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs
index d0c233363a..f47b8f19b1 100644
--- a/src/daft-sql/src/planner.rs
+++ b/src/daft-sql/src/planner.rs
@@ -407,7 +407,7 @@ impl SQLPlanner {
                     .catalog
                     .get_table(&table_name)
                     .ok_or_else(|| PlannerError::table_not_found(table_name.clone()))?;
-                let plan_builder = LogicalPlanBuilder::new(plan);
+                let plan_builder = LogicalPlanBuilder::new(plan, None);
                 Ok(Relation::new(plan_builder, table_name))
             }
             _ => todo!(),
diff --git a/src/daft-sql/src/python.rs b/src/daft-sql/src/python.rs
index 7835b6bde7..7dd5556878 100644
--- a/src/daft-sql/src/python.rs
+++ b/src/daft-sql/src/python.rs
@@ -1,3 +1,4 @@
+use common_daft_config::PyDaftPlanningConfig;
 use daft_dsl::python::PyExpr;
 use daft_plan::{LogicalPlanBuilder, PyLogicalPlanBuilder};
 use pyo3::prelude::*;
@@ -5,10 +6,14 @@ use pyo3::prelude::*;
 use crate::{catalog::SQLCatalog, planner::SQLPlanner};
 
 #[pyfunction]
-pub fn sql(sql: &str, catalog: PyCatalog) -> PyResult<PyLogicalPlanBuilder> {
+pub fn sql(
+    sql: &str,
+    catalog: PyCatalog,
+    daft_planning_config: PyDaftPlanningConfig,
+) -> PyResult<PyLogicalPlanBuilder> {
     let mut planner = SQLPlanner::new(catalog.catalog);
     let plan = planner.plan_sql(sql)?;
-    Ok(LogicalPlanBuilder::new(plan).into())
+    Ok(LogicalPlanBuilder::new(plan, Some(daft_planning_config.config)).into())
 }
 
 #[pyfunction]
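
For reviewers, a minimal sketch (not part of the change set) of how the reworked builder API is intended to compose: a builder constructed with `None` behaves as before, `with_config` attaches a `DaftPlanningConfig`, and because every logical op now returns `self.with_new_plan(..)`, the config stays attached through the chain and is later consulted by `optimize()`. The function name `build_filtered_plan` and the `source_plan` argument are illustrative only; the imports assume the crate-root re-exports already relied on elsewhere in this diff.

```rust
// Illustrative sketch, not part of the patch.
use std::sync::Arc;

use common_daft_config::DaftPlanningConfig;
use common_error::DaftResult;
use daft_dsl::{col, lit};
use daft_plan::{LogicalPlanBuilder, LogicalPlanRef};

fn build_filtered_plan(source_plan: LogicalPlanRef) -> DaftResult<LogicalPlanBuilder> {
    // A builder created without a config behaves exactly as it did before this change.
    let builder = LogicalPlanBuilder::new(source_plan, None);

    // Attaching a planning config returns a new builder carrying the config.
    let config = Arc::new(DaftPlanningConfig::default());
    let builder = builder.with_config(config);

    // Ops such as filter/select go through with_new_plan, so the config is preserved
    // on the returned builder and will parametrize later optimization.
    builder.filter(col("id").eq(lit(1)))
}
```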