From 76e256a7400fb2af1df34d8fe7f3d6e5ec1d0cad Mon Sep 17 00:00:00 2001 From: Jay Chia <17691182+jaychia@users.noreply.github.com> Date: Thu, 19 Oct 2023 05:52:16 +0900 Subject: [PATCH] [PERF] Add "eager mode" to limits and use in .show() (#1498) 1. Adds "eager mode" to limits (when enabled, the global limit physical plan will only yield one in-flight partition at a time instead of allowing for batched execution) 2. Changes our `limit_pushdown` rule to keep the logical limit node around after performing limit pushdowns into the Source node: * This is needed because the logical limit node will be translated into a Physical global limit * Refactored some of the changes introduced in #1476 : translation of Source nodes into the physical plan no longer "implicitly" creates a physical limit node 3. Use an iterator of Tables in `.show()` instead of relying on `.collect()` 4. Adds a max buffer size to the results buffer in our RayRunner - this is so that we can bound the number of results to be buffered to `1` and prevent runaway asynchronous execution by the RayRunner which would otherwise run to completion This PR significantly speeds up show because: 1. Not all partitions need to be materialized (since we don't run a `.collect()`) 2. "Eager" mode results in fewer wasted read tasks (and also gets rid of our `TaskCancelledException` messages) 3. The RayRunner no longer executes all the way to completion, and is instead bounded by a prefetching max size of 1 --------- Co-authored-by: Jay Chia --- daft/daft.pyi | 2 +- daft/dataframe/dataframe.py | 29 +++-- daft/execution/physical_plan.py | 7 ++ daft/execution/physical_plan_factory.py | 1 + daft/logical/builder.py | 2 +- daft/logical/logical_plan.py | 11 +- daft/logical/optimizer.py | 2 +- daft/logical/rust_logical_plan.py | 4 +- daft/runners/pyrunner.py | 11 +- daft/runners/ray_runner.py | 14 ++- daft/runners/runner.py | 20 +++- src/daft-plan/src/builder.rs | 9 +- src/daft-plan/src/logical_ops/limit.rs | 11 +- src/daft-plan/src/logical_plan.rs | 2 +- .../src/optimization/rules/push_down_limit.rs | 110 ++++++++++-------- src/daft-plan/src/physical_ops/limit.rs | 9 +- src/daft-plan/src/physical_plan.rs | 44 ++----- src/daft-plan/src/planner.rs | 7 +- 18 files changed, 167 insertions(+), 128 deletions(-) diff --git a/daft/daft.pyi b/daft/daft.pyi index 786d00e374..dffd96aa44 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -720,7 +720,7 @@ class LogicalPlanBuilder: ) -> LogicalPlanBuilder: ... def project(self, projection: list[PyExpr], resource_request: ResourceRequest) -> LogicalPlanBuilder: ... def filter(self, predicate: PyExpr) -> LogicalPlanBuilder: ... - def limit(self, limit: int) -> LogicalPlanBuilder: ... + def limit(self, limit: int, eager: bool) -> LogicalPlanBuilder: ... def explode(self, to_explode: list[PyExpr]) -> LogicalPlanBuilder: ... def sort(self, sort_by: list[PyExpr], descending: list[bool]) -> LogicalPlanBuilder: ... def repartition( diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index 54030c0553..6fa959e43c 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -169,19 +169,24 @@ def show(self, n: int = 8) -> "DataFrameDisplay": Returns: DataFrameDisplay: object that has a rich tabular display """ - df = self - df = df.limit(n) - df.collect(num_preview_rows=None) - collected_preview = df._preview - assert collected_preview is not None - + builder = self._builder.limit(n, eager=True) + + # Iteratively retrieve partitions until enough data has been materialized + tables = [] + seen = 0 + for table in get_context().runner().run_iter_tables(builder, results_buffer_size=1): + tables.append(table) + seen += len(table) + if seen >= n: + break + + preview_partition = Table.concat(tables) + preview_partition = preview_partition if len(preview_partition) <= n else preview_partition.slice(0, n) preview = DataFramePreview( - preview_partition=collected_preview.preview_partition, - # Override dataframe_num_rows=None, because we do not know - # the size of the entire (un-limited) dataframe when showing + preview_partition=preview_partition, + # We do not know the size of the entire (un-limited) dataframe when showing dataframe_num_rows=None, ) - return DataFrameDisplay(preview, self.schema(), num_rows=n) @DataframePublicAPI @@ -586,11 +591,13 @@ def limit(self, num: int) -> "DataFrame": Args: num (int): maximum rows to allow. + eager (bool): whether to maximize for latency (time to first result) by eagerly executing + only one partition at a time, or throughput by executing multiple limits at a time Returns: DataFrame: Limited DataFrame """ - builder = self._builder.limit(num) + builder = self._builder.limit(num, eager=False) return DataFrame(builder) @DataframePublicAPI diff --git a/daft/execution/physical_plan.py b/daft/execution/physical_plan.py index e4bc572bc6..4bf3f8ff95 100644 --- a/daft/execution/physical_plan.py +++ b/daft/execution/physical_plan.py @@ -277,6 +277,7 @@ def local_limit( def global_limit( child_plan: InProgressPhysicalPlan[PartitionT], limit_rows: int, + eager: bool, num_partitions: int, ) -> InProgressPhysicalPlan[PartitionT]: """Return the first n rows from the `child_plan`.""" @@ -342,6 +343,12 @@ def global_limit( yield None continue + # If running in eager mode, only allow one task in flight + if eager and len(materializations) > 0: + logger.debug(f"global_limit blocking on eager execution of: {materializations[0]}") + yield None + continue + # Execute a single child partition. try: child_step = child_plan.send(remaining_rows) if started else next(child_plan) diff --git a/daft/execution/physical_plan_factory.py b/daft/execution/physical_plan_factory.py index c90cd148ea..2211e5929a 100644 --- a/daft/execution/physical_plan_factory.py +++ b/daft/execution/physical_plan_factory.py @@ -104,6 +104,7 @@ def _get_physical_plan(node: LogicalPlan, psets: dict[str, list[PartitionT]]) -> return physical_plan.global_limit( child_plan=child_plan, limit_rows=node._num, + eager=node._eager, num_partitions=node.num_partitions(), ) diff --git a/daft/logical/builder.py b/daft/logical/builder.py index 78beb92a88..40ae5e421b 100644 --- a/daft/logical/builder.py +++ b/daft/logical/builder.py @@ -100,7 +100,7 @@ def filter(self, predicate: Expression) -> LogicalPlanBuilder: pass @abstractmethod - def limit(self, num_rows: int) -> LogicalPlanBuilder: + def limit(self, num_rows: int, eager: bool) -> LogicalPlanBuilder: pass @abstractmethod diff --git a/daft/logical/logical_plan.py b/daft/logical/logical_plan.py index dbba21bc07..338f512605 100644 --- a/daft/logical/logical_plan.py +++ b/daft/logical/logical_plan.py @@ -150,9 +150,9 @@ def project( def filter(self, predicate: Expression): return Filter(self._plan, ExpressionsProjection([predicate])).to_builder() - def limit(self, num_rows: int) -> LogicalPlanBuilder: + def limit(self, num_rows: int, eager: bool) -> LogicalPlanBuilder: local_limit = LocalLimit(self._plan, num=num_rows) - plan = GlobalLimit(local_limit, num=num_rows) + plan = GlobalLimit(local_limit, num=num_rows, eager=eager) return plan.to_builder() def explode(self, explode_expressions: list[Expression]) -> PyLogicalPlanBuilder: @@ -828,17 +828,18 @@ def rebuild(self) -> LogicalPlan: class GlobalLimit(UnaryNode): - def __init__(self, input: LogicalPlan, num: int) -> None: + def __init__(self, input: LogicalPlan, num: int, eager: bool) -> None: super().__init__(input.schema(), partition_spec=input.partition_spec(), op_level=OpLevel.GLOBAL) self._register_child(input) self._num = num + self._eager = eager def __repr__(self) -> str: return self._repr_helper(num=self._num) def copy_with_new_children(self, new_children: list[LogicalPlan]) -> LogicalPlan: assert len(new_children) == 1 - return GlobalLimit(new_children[0], self._num) + return GlobalLimit(new_children[0], self._num, self._eager) def required_columns(self) -> list[set[str]]: return [set()] @@ -850,7 +851,7 @@ def _local_eq(self, other: Any) -> bool: return isinstance(other, GlobalLimit) and self.schema() == other.schema() and self._num == other._num def rebuild(self) -> LogicalPlan: - return GlobalLimit(input=self._children()[0].rebuild(), num=self._num) + return GlobalLimit(input=self._children()[0].rebuild(), num=self._num, eager=self._eager) class LocalCount(UnaryNode): diff --git a/daft/logical/optimizer.py b/daft/logical/optimizer.py index 6b4db7537a..97210f03de 100644 --- a/daft/logical/optimizer.py +++ b/daft/logical/optimizer.py @@ -457,7 +457,7 @@ def _push_down_global_limit_into_unary_node(self, parent: GlobalLimit, child: Un """ logger.debug(f"pushing {parent} into {child}") grandchild = child._children()[0] - return child.copy_with_new_children([GlobalLimit(grandchild, num=parent._num)]) + return child.copy_with_new_children([GlobalLimit(grandchild, num=parent._num, eager=parent._eager)]) @property def _supported_unary_nodes(self) -> set[type[LogicalPlan]]: diff --git a/daft/logical/rust_logical_plan.py b/daft/logical/rust_logical_plan.py index a3bc920be2..3b48e24b4b 100644 --- a/daft/logical/rust_logical_plan.py +++ b/daft/logical/rust_logical_plan.py @@ -82,8 +82,8 @@ def filter(self, predicate: Expression) -> RustLogicalPlanBuilder: builder = self._builder.filter(predicate._expr) return RustLogicalPlanBuilder(builder) - def limit(self, num_rows: int) -> RustLogicalPlanBuilder: - builder = self._builder.limit(num_rows) + def limit(self, num_rows: int, eager: bool) -> RustLogicalPlanBuilder: + builder = self._builder.limit(num_rows, eager) return RustLogicalPlanBuilder(builder) def explode(self, explode_expressions: list[Expression]) -> RustLogicalPlanBuilder: diff --git a/daft/runners/pyrunner.py b/daft/runners/pyrunner.py index 37f285ec82..df0321f60a 100644 --- a/daft/runners/pyrunner.py +++ b/daft/runners/pyrunner.py @@ -139,7 +139,12 @@ def run(self, builder: LogicalPlanBuilder) -> PartitionCacheEntry: pset_entry = self.put_partition_set_into_cache(result_pset) return pset_entry - def run_iter(self, builder: LogicalPlanBuilder) -> Iterator[Table]: + def run_iter( + self, + builder: LogicalPlanBuilder, + # NOTE: PyRunner does not run any async execution, so it ignores `results_buffer_size` which is essentially 0 + results_buffer_size: int | None = None, + ) -> Iterator[Table]: # Optimize the logical plan. builder = builder.optimize() # Finalize the logical plan and get a physical plan scheduler for translating the @@ -157,8 +162,8 @@ def run_iter(self, builder: LogicalPlanBuilder) -> Iterator[Table]: partitions_gen = self._physical_plan_to_partitions(tasks) yield from partitions_gen - def run_iter_tables(self, builder: LogicalPlanBuilder) -> Iterator[Table]: - return self.run_iter(builder) + def run_iter_tables(self, builder: LogicalPlanBuilder, results_buffer_size: int | None = None) -> Iterator[Table]: + return self.run_iter(builder, results_buffer_size=results_buffer_size) def _physical_plan_to_partitions(self, plan: physical_plan.MaterializedPhysicalPlan) -> Iterator[Table]: inflight_tasks: dict[str, PartitionTask] = dict() diff --git a/daft/runners/ray_runner.py b/daft/runners/ray_runner.py index c4b97a1738..5b36cfacfe 100644 --- a/daft/runners/ray_runner.py +++ b/daft/runners/ray_runner.py @@ -4,7 +4,6 @@ import threading import time import uuid -from collections import defaultdict from dataclasses import dataclass from datetime import datetime from queue import Queue @@ -412,7 +411,7 @@ def __init__(self, max_task_backlog: int | None) -> None: self.reserved_cores = 0 self.threads_by_df: dict[str, threading.Thread] = dict() - self.results_by_df: dict[str, Queue] = defaultdict(Queue) + self.results_by_df: dict[str, Queue] = {} def next(self, result_uuid: str) -> ray.ObjectRef | StopIteration: # Case: thread is terminated and no longer exists. @@ -435,7 +434,10 @@ def run_plan( plan_scheduler: PhysicalPlanScheduler, psets: dict[str, ray.ObjectRef], result_uuid: str, + results_buffer_size: int | None = None, ) -> None: + self.results_by_df[result_uuid] = Queue(maxsize=results_buffer_size or -1) + t = threading.Thread( target=self._run_plan, name=result_uuid, @@ -624,7 +626,7 @@ def __init__( max_task_backlog=max_task_backlog, ) - def run_iter(self, builder: LogicalPlanBuilder) -> Iterator[ray.ObjectRef]: + def run_iter(self, builder: LogicalPlanBuilder, results_buffer_size: int | None = None) -> Iterator[ray.ObjectRef]: # Optimize the logical plan. builder = builder.optimize() # Finalize the logical plan and get a physical plan scheduler for translating the @@ -643,6 +645,7 @@ def run_iter(self, builder: LogicalPlanBuilder) -> Iterator[ray.ObjectRef]: plan_scheduler=plan_scheduler, psets=psets, result_uuid=result_uuid, + results_buffer_size=results_buffer_size, ) ) @@ -651,6 +654,7 @@ def run_iter(self, builder: LogicalPlanBuilder) -> Iterator[ray.ObjectRef]: plan_scheduler=plan_scheduler, psets=psets, result_uuid=result_uuid, + results_buffer_size=results_buffer_size, ) while True: @@ -663,8 +667,8 @@ def run_iter(self, builder: LogicalPlanBuilder) -> Iterator[ray.ObjectRef]: return yield result - def run_iter_tables(self, builder: LogicalPlanBuilder) -> Iterator[Table]: - for ref in self.run_iter(builder): + def run_iter_tables(self, builder: LogicalPlanBuilder, results_buffer_size: int | None = None) -> Iterator[Table]: + for ref in self.run_iter(builder, results_buffer_size=results_buffer_size): yield ray.get(ref) def run(self, builder: LogicalPlanBuilder) -> PartitionCacheEntry: diff --git a/daft/runners/runner.py b/daft/runners/runner.py index 24128a9840..528f3b3337 100644 --- a/daft/runners/runner.py +++ b/daft/runners/runner.py @@ -34,11 +34,23 @@ def run(self, builder: LogicalPlanBuilder) -> PartitionCacheEntry: ... @abstractmethod - def run_iter(self, builder: LogicalPlanBuilder) -> Iterator[PartitionT]: - """Similar to run(), but yield the individual partitions as they are completed.""" + def run_iter(self, builder: LogicalPlanBuilder, results_buffer_size: int | None = None) -> Iterator[PartitionT]: + """Similar to run(), but yield the individual partitions as they are completed. + + Args: + builder: the builder for the LogicalPlan that is to be executed + results_buffer_size: if the plan is executed asynchronously, this is the maximum size of the number of results + that can be buffered before execution should pause and wait. + """ ... @abstractmethod - def run_iter_tables(self, builder: LogicalPlanBuilder) -> Iterator[Table]: - """Similar to run_iter(), but always dereference and yield Table objects.""" + def run_iter_tables(self, builder: LogicalPlanBuilder, results_buffer_size: int | None = None) -> Iterator[Table]: + """Similar to run_iter(), but always dereference and yield Table objects. + + Args: + builder: the builder for the LogicalPlan that is to be executed + results_buffer_size: if the plan is executed asynchronously, this is the maximum size of the number of results + that can be buffered before execution should pause and wait. + """ ... diff --git a/src/daft-plan/src/builder.rs b/src/daft-plan/src/builder.rs index 142504825b..005fcabdd0 100644 --- a/src/daft-plan/src/builder.rs +++ b/src/daft-plan/src/builder.rs @@ -119,8 +119,9 @@ impl LogicalPlanBuilder { Ok(logical_plan.into()) } - pub fn limit(&self, limit: i64) -> DaftResult { - let logical_plan: LogicalPlan = logical_ops::Limit::new(self.plan.clone(), limit).into(); + pub fn limit(&self, limit: i64, eager: bool) -> DaftResult { + let logical_plan: LogicalPlan = + logical_ops::Limit::new(self.plan.clone(), limit, eager).into(); Ok(logical_plan.into()) } @@ -317,8 +318,8 @@ impl PyLogicalPlanBuilder { Ok(self.builder.filter(predicate.expr)?.into()) } - pub fn limit(&self, limit: i64) -> PyResult { - Ok(self.builder.limit(limit)?.into()) + pub fn limit(&self, limit: i64, eager: bool) -> PyResult { + Ok(self.builder.limit(limit, eager)?.into()) } pub fn explode(&self, to_explode: Vec) -> PyResult { diff --git a/src/daft-plan/src/logical_ops/limit.rs b/src/daft-plan/src/logical_ops/limit.rs index 4425e6db3d..4d91ee4a84 100644 --- a/src/daft-plan/src/logical_ops/limit.rs +++ b/src/daft-plan/src/logical_ops/limit.rs @@ -8,10 +8,17 @@ pub struct Limit { pub input: Arc, // Limit on number of rows. pub limit: i64, + // Whether to send tasks in waves (maximize throughput) or + // eagerly one-at-a-time (maximize time-to-first-result) + pub eager: bool, } impl Limit { - pub(crate) fn new(input: Arc, limit: i64) -> Self { - Self { input, limit } + pub(crate) fn new(input: Arc, limit: i64, eager: bool) -> Self { + Self { + input, + limit, + eager, + } } } diff --git a/src/daft-plan/src/logical_plan.rs b/src/daft-plan/src/logical_plan.rs index 454e2e1116..5fba68d491 100644 --- a/src/daft-plan/src/logical_plan.rs +++ b/src/daft-plan/src/logical_plan.rs @@ -226,7 +226,7 @@ impl LogicalPlan { input.clone(), projection.clone(), resource_request.clone(), ).unwrap()), Self::Filter(Filter { predicate, .. }) => Self::Filter(Filter::try_new(input.clone(), predicate.clone()).unwrap()), - Self::Limit(Limit { limit, .. }) => Self::Limit(Limit::new(input.clone(), *limit)), + Self::Limit(Limit { limit, eager, .. }) => Self::Limit(Limit::new(input.clone(), *limit, *eager)), Self::Explode(Explode { to_explode, .. }) => Self::Explode(Explode::try_new(input.clone(), to_explode.clone()).unwrap()), Self::Sort(Sort { sort_by, descending, .. }) => Self::Sort(Sort::try_new(input.clone(), sort_by.clone(), descending.clone()).unwrap()), Self::Repartition(Repartition { num_partitions, partition_by, scheme, .. }) => Self::Repartition(Repartition::new(input.clone(), *num_partitions, partition_by.clone(), scheme.clone())), diff --git a/src/daft-plan/src/optimization/rules/push_down_limit.rs b/src/daft-plan/src/optimization/rules/push_down_limit.rs index 233d0390d9..e4baa1f0c1 100644 --- a/src/daft-plan/src/optimization/rules/push_down_limit.rs +++ b/src/daft-plan/src/optimization/rules/push_down_limit.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use common_error::DaftResult; -use crate::{source_info::SourceInfo, LogicalPlan}; +use crate::{logical_ops::Limit as LogicalLimit, source_info::SourceInfo, LogicalPlan}; use super::{ApplyOrder, OptimizerRule, Transformed}; @@ -22,44 +22,48 @@ impl OptimizerRule for PushDownLimit { } fn try_optimize(&self, plan: Arc) -> DaftResult>> { - let limit = match plan.as_ref() { - LogicalPlan::Limit(limit) => limit, - _ => return Ok(Transformed::No(plan)), - }; - let child_plan = limit.input.as_ref(); - let new_plan = match child_plan { - LogicalPlan::Repartition(_) | LogicalPlan::Coalesce(_) | LogicalPlan::Project(_) => { - // Naive commuting with unary ops. - // - // Limit-UnaryOp -> UnaryOp-Limit - let new_limit = plan.with_new_children(&[child_plan.children()[0].clone()]); - child_plan.with_new_children(&[new_limit]) - } - LogicalPlan::Source(source) => { - // Push limit into source. - // - // Limit-Source -> Source[with_limit] - - // Limit pushdown is only supported for external sources. - if !matches!(source.source_info.as_ref(), SourceInfo::ExternalInfo(_)) { - return Ok(Transformed::No(plan)); - } - let row_limit = limit.limit as usize; - // If source already has limit and the existing limit is less than the new limit, unlink the - // Limit node from the plan and leave the Source node untouched. - if let Some(existing_source_limit) = source.limit && existing_source_limit <= row_limit { - // We directly clone the Limit child rather than creating a new Arc on child_plan to elide - // an extra Arc. - limit.input.clone() - } else { - // Push limit into Source. - let new_source: LogicalPlan = source.with_limit(Some(row_limit)).into(); - new_source.into() + match plan.as_ref() { + LogicalPlan::Limit(LogicalLimit { input, limit, .. }) => { + let limit = *limit as usize; + match input.as_ref() { + // Naive commuting with unary ops. + // + // Limit-UnaryOp -> UnaryOp-Limit + LogicalPlan::Repartition(_) + | LogicalPlan::Coalesce(_) + | LogicalPlan::Project(_) => { + let new_limit = plan.with_new_children(&[input.children()[0].clone()]); + Ok(Transformed::Yes(input.with_new_children(&[new_limit]))) + } + // Push limit into source as a "local" limit. + // + // Limit-Source -> Limit-Source[with_limit] + LogicalPlan::Source(source) => { + match (source.source_info.as_ref(), source.limit) { + // Limit pushdown is not supported for in-memory sources. + #[cfg(feature = "python")] + (SourceInfo::InMemoryInfo(_), _) => Ok(Transformed::No(plan)), + // Do not pushdown if Source node is already more limited than `limit` + (SourceInfo::ExternalInfo(_), Some(existing_source_limit)) + if (existing_source_limit <= limit) => + { + Ok(Transformed::No(plan)) + } + // Pushdown limit into the Source node as a "local" limit + (SourceInfo::ExternalInfo(_), _) => { + let new_source = + LogicalPlan::Source(source.with_limit(Some(limit))).into(); + let limit_with_local_limited_source = + plan.with_new_children(&[new_source]); + Ok(Transformed::Yes(limit_with_local_limited_source)) + } + } + } + _ => Ok(Transformed::No(plan)), } } - _ => return Ok(Transformed::No(plan)), - }; - Ok(Transformed::Yes(new_plan)) + _ => Ok(Transformed::No(plan)), + } } } @@ -116,10 +120,11 @@ mod tests { Field::new("a", DataType::Int64), Field::new("b", DataType::Utf8), ]) - .limit(5)? + .limit(5, false)? .build(); let expected = "\ - Source: Json, File paths = [/foo], File schema = a (Int64), b (Utf8), Format-specific config = Json(JsonSourceConfig), Storage config = Native(NativeStorageConfig { io_config: None }), Output schema = a (Int64), b (Utf8), Limit = 5"; + Limit: 5\ + \n Source: Json, File paths = [/foo], File schema = a (Int64), b (Utf8), Format-specific config = Json(JsonSourceConfig), Storage config = Native(NativeStorageConfig { io_config: None }), Output schema = a (Int64), b (Utf8), Limit = 5"; assert_optimized_plan_eq(plan, expected)?; Ok(()) } @@ -136,10 +141,11 @@ mod tests { ], Some(3), ) - .limit(5)? + .limit(5, false)? .build(); let expected = "\ - Source: Json, File paths = [/foo], File schema = a (Int64), b (Utf8), Format-specific config = Json(JsonSourceConfig), Storage config = Native(NativeStorageConfig { io_config: None }), Output schema = a (Int64), b (Utf8), Limit = 3"; + Limit: 5\ + \n Source: Json, File paths = [/foo], File schema = a (Int64), b (Utf8), Format-specific config = Json(JsonSourceConfig), Storage config = Native(NativeStorageConfig { io_config: None }), Output schema = a (Int64), b (Utf8), Limit = 3"; assert_optimized_plan_eq(plan, expected)?; Ok(()) } @@ -156,10 +162,11 @@ mod tests { ], Some(10), ) - .limit(5)? + .limit(5, false)? .build(); let expected = "\ - Source: Json, File paths = [/foo], File schema = a (Int64), b (Utf8), Format-specific config = Json(JsonSourceConfig), Storage config = Native(NativeStorageConfig { io_config: None }), Output schema = a (Int64), b (Utf8), Limit = 5"; + Limit: 5\ + \n Source: Json, File paths = [/foo], File schema = a (Int64), b (Utf8), Format-specific config = Json(JsonSourceConfig), Storage config = Native(NativeStorageConfig { io_config: None }), Output schema = a (Int64), b (Utf8), Limit = 5"; assert_optimized_plan_eq(plan, expected)?; Ok(()) } @@ -173,7 +180,7 @@ mod tests { let py_obj = Python::with_gil(|py| py.None()); let schema: Arc = Schema::new(vec![Field::new("a", DataType::Int64)])?.into(); let plan = LogicalPlanBuilder::in_memory_scan("foo", py_obj, schema, Default::default())? - .limit(5)? + .limit(5, false)? .build(); let expected = "\ Limit: 5\ @@ -192,11 +199,12 @@ mod tests { Field::new("b", DataType::Utf8), ]) .repartition(1, vec![col("a")], PartitionScheme::Hash)? - .limit(5)? + .limit(5, false)? .build(); let expected = "\ Repartition: Scheme = Hash, Number of partitions = 1, Partition by = col(a)\ - \n Source: Json, File paths = [/foo], File schema = a (Int64), b (Utf8), Format-specific config = Json(JsonSourceConfig), Storage config = Native(NativeStorageConfig { io_config: None }), Output schema = a (Int64), b (Utf8), Limit = 5"; + \n Limit: 5\ + \n Source: Json, File paths = [/foo], File schema = a (Int64), b (Utf8), Format-specific config = Json(JsonSourceConfig), Storage config = Native(NativeStorageConfig { io_config: None }), Output schema = a (Int64), b (Utf8), Limit = 5"; assert_optimized_plan_eq(plan, expected)?; Ok(()) } @@ -211,11 +219,12 @@ mod tests { Field::new("b", DataType::Utf8), ]) .coalesce(1)? - .limit(5)? + .limit(5, false)? .build(); let expected = "\ Coalesce: To = 1\ - \n Source: Json, File paths = [/foo], File schema = a (Int64), b (Utf8), Format-specific config = Json(JsonSourceConfig), Storage config = Native(NativeStorageConfig { io_config: None }), Output schema = a (Int64), b (Utf8), Limit = 5"; + \n Limit: 5\ + \n Source: Json, File paths = [/foo], File schema = a (Int64), b (Utf8), Format-specific config = Json(JsonSourceConfig), Storage config = Native(NativeStorageConfig { io_config: None }), Output schema = a (Int64), b (Utf8), Limit = 5"; assert_optimized_plan_eq(plan, expected)?; Ok(()) } @@ -230,11 +239,12 @@ mod tests { Field::new("b", DataType::Utf8), ]) .project(vec![col("a")], Default::default())? - .limit(5)? + .limit(5, false)? .build(); let expected = "\ Project: col(a), Partition spec = PartitionSpec { scheme: Unknown, num_partitions: 1, by: None }\ - \n Source: Json, File paths = [/foo], File schema = a (Int64), b (Utf8), Format-specific config = Json(JsonSourceConfig), Storage config = Native(NativeStorageConfig { io_config: None }), Output schema = a (Int64), b (Utf8), Limit = 5"; + \n Limit: 5\ + \n Source: Json, File paths = [/foo], File schema = a (Int64), b (Utf8), Format-specific config = Json(JsonSourceConfig), Storage config = Native(NativeStorageConfig { io_config: None }), Output schema = a (Int64), b (Utf8), Limit = 5"; assert_optimized_plan_eq(plan, expected)?; Ok(()) } diff --git a/src/daft-plan/src/physical_ops/limit.rs b/src/daft-plan/src/physical_ops/limit.rs index d96c952ac7..c64346530c 100644 --- a/src/daft-plan/src/physical_ops/limit.rs +++ b/src/daft-plan/src/physical_ops/limit.rs @@ -8,14 +8,21 @@ pub struct Limit { // Upstream node. pub input: Arc, pub limit: i64, + pub eager: bool, pub num_partitions: usize, } impl Limit { - pub(crate) fn new(input: Arc, limit: i64, num_partitions: usize) -> Self { + pub(crate) fn new( + input: Arc, + limit: i64, + eager: bool, + num_partitions: usize, + ) -> Self { Self { input, limit, + eager, num_partitions, } } diff --git a/src/daft-plan/src/physical_plan.rs b/src/daft-plan/src/physical_plan.rs index 28711586bb..b66388be59 100644 --- a/src/daft-plan/src/physical_plan.rs +++ b/src/daft-plan/src/physical_plan.rs @@ -6,7 +6,6 @@ use { ExternalInfo, FileFormat, FileFormatConfig, FileInfos, InMemoryInfo, PyFileFormatConfig, PyStorageConfig, StorageConfig, }, - PartitionSpec, }, daft_core::python::schema::PySchema, daft_core::schema::SchemaRef, @@ -111,7 +110,6 @@ fn tabular_scan( file_infos: &Arc, file_format_config: &Arc, storage_config: &Arc, - partition_spec: &PartitionSpec, limit: &Option, is_ray_runner: bool, ) -> PyResult { @@ -140,16 +138,7 @@ fn tabular_scan( is_ray_runner, ))?; - if let Some(limit) = limit { - apply_limit( - py, - py_iter.into(), - *limit as i64, - partition_spec.num_partitions, - ) - } else { - Ok(py_iter.into()) - } + Ok(py_iter.into()) } #[cfg(feature = "python")] @@ -181,23 +170,6 @@ fn tabular_write( Ok(py_iter.into()) } -#[cfg(feature = "python")] -fn apply_limit( - py: Python<'_>, - upstream_iter: PyObject, - limit: i64, - num_partitions: usize, -) -> PyResult { - let py_physical_plan = py.import(pyo3::intern!(py, "daft.execution.physical_plan"))?; - let local_limit_iter = py_physical_plan - .getattr(pyo3::intern!(py, "local_limit"))? - .call1((upstream_iter, limit))?; - let global_limit_iter = py_physical_plan - .getattr(pyo3::intern!(py, "global_limit"))? - .call1((local_limit_iter, limit, num_partitions))?; - Ok(global_limit_iter.into()) -} - #[cfg(feature = "python")] impl PhysicalPlan { pub fn to_partition_tasks( @@ -231,7 +203,6 @@ impl PhysicalPlan { storage_config, .. }, - partition_spec, limit, .. }) => tabular_scan( @@ -241,7 +212,6 @@ impl PhysicalPlan { file_infos, file_format_config, storage_config, - partition_spec, limit, is_ray_runner, ), @@ -255,7 +225,6 @@ impl PhysicalPlan { storage_config, .. }, - partition_spec, limit, .. }) => tabular_scan( @@ -265,7 +234,6 @@ impl PhysicalPlan { file_infos, file_format_config, storage_config, - partition_spec, limit, is_ray_runner, ), @@ -279,7 +247,6 @@ impl PhysicalPlan { storage_config, .. }, - partition_spec, limit, .. }) => tabular_scan( @@ -289,7 +256,6 @@ impl PhysicalPlan { file_infos, file_format_config, storage_config, - partition_spec, limit, is_ray_runner, ), @@ -337,10 +303,16 @@ impl PhysicalPlan { PhysicalPlan::Limit(Limit { input, limit, + eager, num_partitions, }) => { let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; - apply_limit(py, upstream_iter, *limit, *num_partitions) + let py_physical_plan = + py.import(pyo3::intern!(py, "daft.execution.physical_plan"))?; + let global_limit_iter = py_physical_plan + .getattr(pyo3::intern!(py, "global_limit"))? + .call1((upstream_iter, *limit, *eager, *num_partitions))?; + Ok(global_limit_iter.into()) } PhysicalPlan::Explode(Explode { input, to_explode }) => { let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; diff --git a/src/daft-plan/src/planner.rs b/src/daft-plan/src/planner.rs index 31b2f5e0fb..bf35db6cda 100644 --- a/src/daft-plan/src/planner.rs +++ b/src/daft-plan/src/planner.rs @@ -102,11 +102,16 @@ pub fn plan(logical_plan: &LogicalPlan) -> DaftResult { predicate.clone(), ))) } - LogicalPlan::Limit(LogicalLimit { input, limit }) => { + LogicalPlan::Limit(LogicalLimit { + input, + limit, + eager, + }) => { let input_physical = plan(input)?; Ok(PhysicalPlan::Limit(Limit::new( input_physical.into(), *limit, + *eager, logical_plan.partition_spec().num_partitions, ))) }