diff --git a/src/daft-plan/src/physical_plan.rs b/src/daft-plan/src/physical_plan.rs index 9766c530ac..b11c28b0a7 100644 --- a/src/daft-plan/src/physical_plan.rs +++ b/src/daft-plan/src/physical_plan.rs @@ -6,6 +6,7 @@ use { ExternalInfo, FileFormat, FileFormatConfig, FileInfos, InMemoryInfo, PyFileFormatConfig, PyStorageConfig, StorageConfig, }, + PartitionSpec, }, daft_core::python::schema::PySchema, daft_core::schema::SchemaRef, @@ -110,6 +111,7 @@ fn tabular_scan( file_infos: &Arc, file_format_config: &Arc, storage_config: &Arc, + partition_spec: &PartitionSpec, limit: &Option, is_ray_runner: bool, ) -> PyResult { @@ -131,7 +133,17 @@ fn tabular_scan( *limit, is_ray_runner, ))?; - Ok(py_iter.into()) + + if let Some(limit) = limit { + apply_limit( + py, + py_iter.into(), + *limit as i64, + partition_spec.num_partitions, + ) + } else { + Ok(py_iter.into()) + } } #[cfg(feature = "python")] @@ -163,6 +175,23 @@ fn tabular_write( Ok(py_iter.into()) } +#[cfg(feature = "python")] +fn apply_limit( + py: Python<'_>, + upstream_iter: PyObject, + limit: i64, + num_partitions: usize, +) -> PyResult { + let py_physical_plan = py.import(pyo3::intern!(py, "daft.execution.physical_plan"))?; + let local_limit_iter = py_physical_plan + .getattr(pyo3::intern!(py, "local_limit"))? + .call1((upstream_iter, limit))?; + let global_limit_iter = py_physical_plan + .getattr(pyo3::intern!(py, "global_limit"))? + .call1((local_limit_iter, limit, num_partitions))?; + Ok(global_limit_iter.into()) +} + #[cfg(feature = "python")] impl PhysicalPlan { pub fn to_partition_tasks( @@ -196,6 +225,7 @@ impl PhysicalPlan { storage_config, .. }, + partition_spec, limit, .. }) => tabular_scan( @@ -205,6 +235,7 @@ impl PhysicalPlan { file_infos, file_format_config, storage_config, + partition_spec, limit, is_ray_runner, ), @@ -218,6 +249,7 @@ impl PhysicalPlan { storage_config, .. }, + partition_spec, limit, .. }) => tabular_scan( @@ -227,6 +259,7 @@ impl PhysicalPlan { file_infos, file_format_config, storage_config, + partition_spec, limit, is_ray_runner, ), @@ -240,6 +273,7 @@ impl PhysicalPlan { storage_config, .. }, + partition_spec, limit, .. }) => tabular_scan( @@ -249,6 +283,7 @@ impl PhysicalPlan { file_infos, file_format_config, storage_config, + partition_spec, limit, is_ray_runner, ), @@ -299,15 +334,7 @@ impl PhysicalPlan { num_partitions, }) => { let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; - let py_physical_plan = - py.import(pyo3::intern!(py, "daft.execution.physical_plan"))?; - let local_limit_iter = py_physical_plan - .getattr(pyo3::intern!(py, "local_limit"))? - .call1((upstream_iter, *limit))?; - let global_limit_iter = py_physical_plan - .getattr(pyo3::intern!(py, "global_limit"))? - .call1((local_limit_iter, *limit, *num_partitions))?; - Ok(global_limit_iter.into()) + apply_limit(py, upstream_iter, *limit, *num_partitions) } PhysicalPlan::Explode(Explode { input, to_explode }) => { let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?;