From 664848b1c7b287dc653e0a43b79bfc5c62665956 Mon Sep 17 00:00:00 2001 From: Jay Chia Date: Fri, 2 Aug 2024 17:30:52 -0700 Subject: [PATCH] Implement PyRunner execution of ActorPoolProject nodes --- daft/daft.pyi | 3 + daft/execution/execution_step.py | 26 +++ daft/execution/physical_plan.py | 75 ++++++++- daft/execution/rust_physical_plan_shim.py | 4 - daft/expressions/expressions.py | 5 +- daft/pickle/cloudpickle.py | 1 + daft/pickle/cloudpickle_fast.py | 1 + daft/runners/pyrunner.py | 154 +++++++++++++++++- daft/runners/ray_runner.py | 8 + daft/runners/runner.py | 26 +++ src/common/resource-request/src/lib.rs | 12 ++ src/daft-dsl/src/functions/python/mod.rs | 65 +++++++- src/daft-dsl/src/functions/python/udf.rs | 106 +++++++----- .../functions/python/udf_runtime_binding.rs | 70 ++++++++ src/daft-dsl/src/lib.rs | 2 + src/daft-dsl/src/python.rs | 18 ++ src/daft-scheduler/src/scheduler.rs | 31 +--- tests/actor_pool/__init__.py | 0 tests/actor_pool/test_pyactor_pool.py | 74 +++++++++ 19 files changed, 589 insertions(+), 92 deletions(-) create mode 100644 src/daft-dsl/src/functions/python/udf_runtime_binding.rs create mode 100644 tests/actor_pool/__init__.py create mode 100644 tests/actor_pool/test_pyactor_pool.py diff --git a/daft/daft.pyi b/daft/daft.pyi index 2882d003a8..804951ef37 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -179,6 +179,7 @@ class ResourceRequest: def with_num_cpus(self, num_cpus: float | None) -> ResourceRequest: ... def with_num_gpus(self, num_gpus: float | None) -> ResourceRequest: ... def with_memory_bytes(self, memory_bytes: int | None) -> ResourceRequest: ... + def __mul__(self, factor: float) -> ResourceRequest: ... def __add__(self, other: ResourceRequest) -> ResourceRequest: ... def __repr__(self) -> str: ... def __eq__(self, other: ResourceRequest) -> bool: ... # type: ignore[override] @@ -1200,6 +1201,8 @@ def stateful_udf( batch_size: int | None, concurrency: int | None, ) -> PyExpr: ... +def extract_partial_stateful_udf_py(expression: PyExpr) -> dict[str, PartialStatefulUDF]: ... +def bind_stateful_udfs(expression: PyExpr, initialized_funcs: dict[str, Callable]) -> PyExpr: ... def resolve_expr(expr: PyExpr, schema: PySchema) -> tuple[PyExpr, PyField]: ... def hash(expr: PyExpr, seed: Any | None = None) -> PyExpr: ... def cosine_distance(expr: PyExpr, other: PyExpr) -> PyExpr: ... diff --git a/daft/execution/execution_step.py b/daft/execution/execution_step.py index 9f25bc04d1..d3b672bb63 100644 --- a/daft/execution/execution_step.py +++ b/daft/execution/execution_step.py @@ -42,6 +42,11 @@ class PartitionTask(Generic[PartitionT]): num_results: int stage_id: int partial_metadatas: list[PartialPartitionMetadata] + + # Indicates that this PartitionTask must be executed on the executor with the supplied ID + # This is used when a specific executor (e.g. an Actor pool) must be provisioned and used for the task + actor_pool_id: str | None + _id: int = field(default_factory=lambda: next(ID_GEN)) def id(self) -> str: @@ -86,6 +91,7 @@ def __init__( inputs: list[PartitionT], partial_metadatas: list[PartialPartitionMetadata] | None, resource_request: ResourceRequest = ResourceRequest(), + actor_pool_id: str | None = None, ) -> None: self.inputs = inputs if partial_metadatas is not None: @@ -95,6 +101,7 @@ def __init__( self.resource_request: ResourceRequest = resource_request self.instructions: list[Instruction] = list() self.num_results = len(inputs) + self.actor_pool_id = actor_pool_id def add_instruction( self, @@ -132,6 +139,7 @@ def finalize_partition_task_single_output(self, stage_id: int) -> SingleOutputPa num_results=1, resource_request=resource_request_final_cpu, partial_metadatas=self.partial_metadatas, + actor_pool_id=self.actor_pool_id, ) def finalize_partition_task_multi_output(self, stage_id: int) -> MultiOutputPartitionTask[PartitionT]: @@ -152,6 +160,7 @@ def finalize_partition_task_multi_output(self, stage_id: int) -> MultiOutputPart num_results=self.num_results, resource_request=resource_request_final_cpu, partial_metadatas=self.partial_metadatas, + actor_pool_id=self.actor_pool_id, ) def __str__(self) -> str: @@ -527,6 +536,23 @@ def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) ] +@dataclass(frozen=True) +class StatefulUDFProject(SingleOutputInstruction): + projection: ExpressionsProjection + + def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]: + raise NotImplementedError("UDFProject instruction cannot be run from outside an Actor. Please file an issue.") + + def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) -> list[PartialPartitionMetadata]: + return [ + PartialPartitionMetadata( + num_rows=None, # UDFs can potentially change cardinality + size_bytes=None, + boundaries=None, # TODO: figure out if the stateful UDF projection changes boundaries + ) + ] + + def _prune_boundaries(boundaries: Boundaries, projection: ExpressionsProjection) -> Boundaries | None: """ If projection expression is a nontrivial computation (i.e. not a direct col() reference and not an alias) on top of a boundary diff --git a/daft/execution/physical_plan.py b/daft/execution/physical_plan.py index d0d529075a..947c37f961 100644 --- a/daft/execution/physical_plan.py +++ b/daft/execution/physical_plan.py @@ -56,8 +56,6 @@ from pyiceberg.schema import Schema as IcebergSchema from pyiceberg.table import TableProperties as IcebergTableProperties - from daft.udf import PartialStatefulUDF - # A PhysicalPlan that is still being built - may yield both PartitionTaskBuilders and PartitionTasks. InProgressPhysicalPlan = Iterator[Union[None, PartitionTask[PartitionT], PartitionTaskBuilder[PartitionT]]] @@ -204,11 +202,80 @@ def pipeline_instruction( def actor_pool_project( child_plan: InProgressPhysicalPlan[PartitionT], projection: ExpressionsProjection, - partial_stateful_udfs: dict[str, PartialStatefulUDF], resource_request: execution_step.ResourceRequest, num_actors: int, ) -> InProgressPhysicalPlan[PartitionT]: - raise NotImplementedError("Execution of ActorPoolProjects not yet implemented") + stage_id = next(stage_id_counter) + actor_pool_name = f"ActorPool_stage{stage_id}" + + # Keep track of materializations of the children tasks + child_materializations_buffer_len = num_actors * 2 + child_materializations: deque[SingleOutputPartitionTask[PartitionT]] = deque() + + # Keep track of materializations of the actor_pool tasks + actor_pool_materializations: deque[SingleOutputPartitionTask[PartitionT]] = deque() + + with get_context().runner().actor_pool_context( + actor_pool_name, + resource_request, + num_actors, + projection, + ) as actor_pool_id: + child_plan_exhausted = False + + # Loop until the child plan is exhausted and there is no more work in the pipeline + while not (child_plan_exhausted and len(child_materializations) == 0 and len(actor_pool_materializations) == 0): + # Exhaustively pop ready child_steps and submit them to be run on the actor_pool + while len(child_materializations) > 0 and child_materializations[0].done(): + next_ready_child = child_materializations.popleft() + actor_project_step = ( + PartitionTaskBuilder[PartitionT]( + inputs=[next_ready_child.partition()], + partial_metadatas=[next_ready_child.partition_metadata()], + resource_request=resource_request, + actor_pool_id=actor_pool_id, + ) + .add_instruction( + instruction=execution_step.StatefulUDFProject(projection), + ) + .finalize_partition_task_single_output( + stage_id=stage_id, + ) + ) + actor_pool_materializations.append(actor_project_step) + yield actor_project_step + + # Exhaustively pop ready actor_pool steps and bubble it upwards as the start of a new pipeline + while len(actor_pool_materializations) > 0 and actor_pool_materializations[0].done(): + next_ready_actor_pool_task = actor_pool_materializations.popleft() + new_pipeline_starter_task = PartitionTaskBuilder[PartitionT]( + inputs=[next_ready_actor_pool_task.partition()], + partial_metadatas=[next_ready_actor_pool_task.partition_metadata()], + resource_request=ResourceRequest(), + ) + yield new_pipeline_starter_task + + # No more child work to be done: if there is pending work in the pipeline we yield None + if child_plan_exhausted: + if len(child_materializations) > 0 or len(actor_pool_materializations) > 0: + yield None + + # If there is capacity in the pipeline, attempt to schedule child work + elif len(child_materializations) < child_materializations_buffer_len: + try: + child_step = next(child_plan) + except StopIteration: + child_plan_exhausted = True + else: + # Finalize and yield the child step to be run if it is a PartitionTaskBuilder + if isinstance(child_step, PartitionTaskBuilder): + child_step = child_step.finalize_partition_task_single_output(stage_id=stage_id) + child_materializations.append(child_step) + yield child_step + + # Otherwise, indicate that we need to wait for work to complete + else: + yield None def monotonically_increasing_id( diff --git a/daft/execution/rust_physical_plan_shim.py b/daft/execution/rust_physical_plan_shim.py index e06bfe187d..833887893c 100644 --- a/daft/execution/rust_physical_plan_shim.py +++ b/daft/execution/rust_physical_plan_shim.py @@ -23,8 +23,6 @@ from pyiceberg.schema import Schema as IcebergSchema from pyiceberg.table import TableProperties as IcebergTableProperties - from daft.udf import PartialStatefulUDF - def scan_with_tasks( scan_tasks: list[ScanTask], @@ -83,7 +81,6 @@ def project( def actor_pool_project( input: physical_plan.InProgressPhysicalPlan[PartitionT], projection: list[PyExpr], - partial_stateful_udfs: dict[str, PartialStatefulUDF], resource_request: ResourceRequest | None, num_actors: int, ) -> physical_plan.InProgressPhysicalPlan[PartitionT]: @@ -94,7 +91,6 @@ def actor_pool_project( return physical_plan.actor_pool_project( child_plan=input, projection=expr_projection, - partial_stateful_udfs=partial_stateful_udfs, resource_request=resource_request, num_actors=num_actors, ) diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 8de584035b..94fbd88b2e 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -20,7 +20,7 @@ import daft.daft as native from daft import context -from daft.daft import CountMode, ImageFormat, ImageMode, ResourceRequest +from daft.daft import CountMode, ImageFormat, ImageMode, ResourceRequest, bind_stateful_udfs from daft.daft import PyExpr as _PyExpr from daft.daft import col as _col from daft.daft import date_lit as _date_lit @@ -1099,6 +1099,9 @@ def __reduce__(self) -> tuple: def _input_mapping(self) -> builtins.str | None: return self._expr._input_mapping() + def _bind_stateful_udfs(self, initialized_funcs: dict[builtins.str, Callable]) -> Expression: + return Expression._from_pyexpr(bind_stateful_udfs(self._expr, initialized_funcs)) + SomeExpressionNamespace = TypeVar("SomeExpressionNamespace", bound="ExpressionNamespace") diff --git a/daft/pickle/cloudpickle.py b/daft/pickle/cloudpickle.py index 7a92d8977c..872e2edd4e 100644 --- a/daft/pickle/cloudpickle.py +++ b/daft/pickle/cloudpickle.py @@ -1,3 +1,4 @@ +# type: ignore """ Taken from: https://github.com/cloudpipe/cloudpickle/blob/master/cloudpickle/cloudpickle.py diff --git a/daft/pickle/cloudpickle_fast.py b/daft/pickle/cloudpickle_fast.py index 00c66f3f97..d1eafe8d7f 100644 --- a/daft/pickle/cloudpickle_fast.py +++ b/daft/pickle/cloudpickle_fast.py @@ -1,3 +1,4 @@ +# type: ignore """ Code from: https://github.com/cloudpipe/cloudpickle/blob/master/cloudpickle/cloudpickle_fast.py diff --git a/daft/runners/pyrunner.py b/daft/runners/pyrunner.py index 0820af6987..0559e6c501 100644 --- a/daft/runners/pyrunner.py +++ b/daft/runners/pyrunner.py @@ -1,5 +1,6 @@ from __future__ import annotations +import contextlib import logging import threading from concurrent import futures @@ -11,6 +12,7 @@ from daft.execution import physical_plan from daft.execution.execution_step import Instruction, PartitionTask from daft.execution.native_executor import NativeExecutor +from daft.expressions import ExpressionsProjection from daft.filesystem import glob_path_with_stats from daft.internal.gpu import cuda_device_count from daft.logical.builder import LogicalPlanBuilder @@ -27,6 +29,7 @@ from daft.runners.progress_bar import ProgressBar from daft.runners.runner import Runner from daft.table import MicroPartition +from daft.udf import UserProvidedPythonFunction logger = logging.getLogger(__name__) @@ -94,6 +97,100 @@ def wait(self) -> None: pass +class PyActorPool: + initialized_stateful_udfs_process_singleton: dict[str, UserProvidedPythonFunction] | None = None + + def __init__( + self, + pool_id: str, + num_actors: int, + resource_request: ResourceRequest, + projection: ExpressionsProjection, + ): + self._pool_id = pool_id + self._num_actors = num_actors + self._resource_request = resource_request + self._executor: futures.ProcessPoolExecutor | None = None + self._projection = projection + + @staticmethod + def initialize_actor_global_state(uninitialized_projection: ExpressionsProjection): + from daft.daft import extract_partial_stateful_udf_py + + if PyActorPool.initialized_stateful_udfs_process_singleton is not None: + raise RuntimeError("Cannot initialize Python process actor twice.") + else: + partial_stateful_udfs = { + name: psu + for expr in uninitialized_projection + for name, psu in extract_partial_stateful_udf_py(expr._expr).items() + } + + logger.info("Initializing stateful UDFs: %s", ", ".join(partial_stateful_udfs.keys())) + + PyActorPool.initialized_stateful_udfs_process_singleton = { + name: partial_udf.func_cls() for name, partial_udf in partial_stateful_udfs.items() + } + + @staticmethod + def build_partitions_with_stateful_project( + uninitialized_projection: ExpressionsProjection, + partition: MicroPartition, + partial_metadata: PartialPartitionMetadata, + ) -> list[MaterializedResult[MicroPartition]]: + # Bind the expressions to the initialized stateful UDFs, which should already have been initialized at process start-up + initialized_stateful_udfs = PyActorPool.initialized_stateful_udfs_process_singleton + assert ( + initialized_stateful_udfs is not None + ), "PyActor process must be initialized with stateful UDFs before execution" + initialized_projection = ExpressionsProjection( + [e._bind_stateful_udfs(initialized_stateful_udfs) for e in uninitialized_projection] + ) + new_part = partition.eval_expression_list(initialized_projection) + return [ + PyMaterializedResult(new_part, PartitionMetadata.from_table(new_part).merge_with_partial(partial_metadata)) + ] + + def submit( + self, + instruction_stack: list[Instruction], + partitions: list[MicroPartition], + final_metadata: list[PartialPartitionMetadata], + ) -> futures.Future[list[MaterializedResult[MicroPartition]]]: + from daft.execution import execution_step + + assert self._executor is not None, "Cannot submit to uninitialized PyActorPool" + + # PyActorPools can only handle 1 to 1 projections (no fanouts/fan-ins) and only + # StatefulUDFProject instructions (no filters etc) + assert len(partitions) == 1 + assert len(final_metadata) == 1 + assert len(instruction_stack) == 1 + instruction = instruction_stack[0] + assert isinstance(instruction, execution_step.StatefulUDFProject) + projection = instruction.projection + partition = partitions[0] + partial_metadata = final_metadata[0] + + return self._executor.submit( + PyActorPool.build_partitions_with_stateful_project, + projection, + partition, + partial_metadata, + ) + + def teardown(self) -> None: + # Shut down the executor + assert self._executor is not None, "Should have an executor when exiting context" + self._executor.shutdown() + self._executor = None + + def setup(self) -> None: + self._executor = futures.ProcessPoolExecutor( + self._num_actors, initializer=PyActorPool.initialize_actor_global_state, initargs=(self._projection,) + ) + + class PyRunnerIO(runner_io.RunnerIO): def glob_paths_details( self, @@ -121,6 +218,9 @@ def __init__(self, use_thread_pool: bool | None) -> None: self._use_thread_pool: bool = use_thread_pool if use_thread_pool is not None else True self._thread_pool = futures.ThreadPoolExecutor() + # Registry of active ActorPools + self._actor_pools: dict[str, PyActorPool] = {} + # Global accounting of tasks and resources self._inflight_futures: dict[str, futures.Future] = {} @@ -216,6 +316,35 @@ def run_iter_tables( for result in self.run_iter(builder, results_buffer_size=results_buffer_size): yield result.partition() + @contextlib.contextmanager + def actor_pool_context( + self, name: str, resource_request: ResourceRequest, num_actors: int, projection: ExpressionsProjection + ) -> Iterator[str]: + actor_pool_id = f"py_actor_pool-{name}" + + total_resource_request = resource_request * num_actors + admitted = self._attempt_admit_task(total_resource_request) + + if not admitted: + raise RuntimeError( + f"Not enough resources available to admit {num_actors} actors, each with resource request: {resource_request}" + ) + + try: + self._actor_pools[actor_pool_id] = PyActorPool(actor_pool_id, num_actors, resource_request, projection) + self._actor_pools[actor_pool_id].setup() + logger.debug("Created actor pool %s with resources: %s", actor_pool_id, total_resource_request) + yield actor_pool_id + # NOTE: Ensure that teardown always occurs regardless of any errors that occur during actor pool setup or execution + finally: + logger.debug("Tearing down actor pool: %s", actor_pool_id) + with self._resource_accounting_lock: + self._available_bytes_memory += total_resource_request.memory_bytes or 0 + self._available_cpus += total_resource_request.num_cpus or 0.0 + self._available_gpus += total_resource_request.num_gpus or 0.0 + self._actor_pools[actor_pool_id].teardown() + del self._actor_pools[actor_pool_id] + def _physical_plan_to_partitions( self, plan: physical_plan.MaterializedPhysicalPlan[MicroPartition] ) -> Iterator[PyMaterializedResult]: @@ -286,13 +415,24 @@ def _physical_plan_to_partitions( # update progress bar pbar.mark_task_start(next_step) - future = self._thread_pool.submit( - self.build_partitions, - next_step.instructions, - next_step.inputs, - next_step.partial_metadatas, - next_step.resource_request, - ) + if next_step.actor_pool_id is None: + future = self._thread_pool.submit( + self.build_partitions, + next_step.instructions, + next_step.inputs, + next_step.partial_metadatas, + next_step.resource_request, + ) + else: + actor_pool = self._actor_pools.get(next_step.actor_pool_id) + assert ( + actor_pool is not None + ), f"PyActorPool={next_step.actor_pool_id} must outlive the tasks that need to be run on it." + future = actor_pool.submit( + next_step.instructions, + next_step.inputs, + next_step.partial_metadatas, + ) # Register the inflight task assert ( diff --git a/daft/runners/ray_runner.py b/daft/runners/ray_runner.py index 935359bed5..9bd6055f33 100644 --- a/daft/runners/ray_runner.py +++ b/daft/runners/ray_runner.py @@ -1,5 +1,6 @@ from __future__ import annotations +import contextlib import logging import threading import time @@ -11,6 +12,7 @@ import pyarrow as pa from daft.context import get_context, set_execution_config +from daft.expressions import ExpressionsProjection from daft.logical.builder import LogicalPlanBuilder from daft.plan_scheduler import PhysicalPlanScheduler from daft.runners.progress_bar import ProgressBar @@ -875,6 +877,12 @@ def run_iter_tables( for result in self.run_iter(builder, results_buffer_size=results_buffer_size): yield ray.get(result.partition()) + @contextlib.contextmanager + def actor_pool_context( + self, name: str, resource_request: ResourceRequest, num_actors: PartID, projection: ExpressionsProjection + ) -> Iterator[str]: + raise NotImplementedError("Actor pool for RayRunner not yet implemented") + def _collect_into_cache(self, results_iter: Iterator[RayMaterializedResult]) -> PartitionCacheEntry: result_pset = RayPartitionSet() diff --git a/daft/runners/runner.py b/daft/runners/runner.py index 98bcdff7b9..0e864e6469 100644 --- a/daft/runners/runner.py +++ b/daft/runners/runner.py @@ -1,8 +1,11 @@ from __future__ import annotations +import contextlib from abc import abstractmethod from typing import Generic, Iterator +from daft.daft import ResourceRequest +from daft.expressions import ExpressionsProjection from daft.logical.builder import LogicalPlanBuilder from daft.runners.partitioning import ( MaterializedResult, @@ -56,3 +59,26 @@ def run_iter_tables( that can be buffered before execution should pause and wait. """ ... + + @abstractmethod + @contextlib.contextmanager + def actor_pool_context( + self, + name: str, + resource_request: ResourceRequest, + num_actors: int, + projection: ExpressionsProjection, + ) -> Iterator[str]: + """Creates a pool of actors which can execute work, and yield a context in which the pool can be used. + + Also yields a `str` ID which clients can use to refer to the actor pool when submitting tasks. + + Note that attempting to do work outside this context will result in errors! + + Args: + name: Name of the actor pool for debugging/observability + resource_request: Requested amount of resources for each actor + num_actors: Number of actors to spin up + partial_stateful_udf: A stateful UDF that has been "bound" to its arguments, so each actor can run it + """ + ... diff --git a/src/common/resource-request/src/lib.rs b/src/common/resource-request/src/lib.rs index 10d8801a78..1baa51d745 100644 --- a/src/common/resource-request/src/lib.rs +++ b/src/common/resource-request/src/lib.rs @@ -105,6 +105,14 @@ impl ResourceRequest { .iter() .fold(Default::default(), |acc, e| acc.max(e.as_ref())) } + + pub fn multiply(&self, factor: f64) -> Self { + Self::new_internal( + self.num_cpus.map(|x| x * factor), + self.num_gpus.map(|x| x * factor), + self.memory_bytes.map(|x| x * (factor as usize)), + ) + } } impl Add for &ResourceRequest { @@ -207,6 +215,10 @@ impl ResourceRequest { self + other } + fn __mul__(&self, factor: f64) -> Self { + self.multiply(factor) + } + fn __richcmp__(&self, other: &Self, op: CompareOp) -> bool { match op { CompareOp::Eq => self == other, diff --git a/src/daft-dsl/src/functions/python/mod.rs b/src/daft-dsl/src/functions/python/mod.rs index b5bb977dec..0cb5b1876c 100644 --- a/src/daft-dsl/src/functions/python/mod.rs +++ b/src/daft-dsl/src/functions/python/mod.rs @@ -1,10 +1,12 @@ #[cfg(feature = "python")] mod pyobj_serde; mod udf; +#[cfg(feature = "python")] +mod udf_runtime_binding; -use std::sync::Arc; +use std::{collections::HashMap, sync::Arc}; -use common_error::DaftResult; +use common_error::{DaftError, DaftResult}; use common_resource_request::ResourceRequest; use common_treenode::{TreeNode, TreeNodeRecursion}; use daft_core::datatypes::DataType; @@ -54,6 +56,8 @@ pub struct StatefulPythonUDF { pub init_args: Option, pub batch_size: Option, pub concurrency: Option, + #[cfg(feature = "python")] + pub runtime_binding: udf_runtime_binding::UDFRuntimeBinding, } #[cfg(feature = "python")] @@ -120,6 +124,7 @@ pub fn stateful_udf( init_args: init_args.map(pyobj_serde::PyObjectWrapper), batch_size, concurrency, + runtime_binding: udf_runtime_binding::UDFRuntimeBinding::Unbound, })), inputs: expressions.into(), }) @@ -219,3 +224,59 @@ pub fn get_concurrency(exprs: &[ExprRef]) -> usize { } projection_concurrency.expect("get_concurrency expects one StatefulUDF") } + +/// Binds every StatefulPythonUDF expression to an initialized function provided by an actor +#[cfg(feature = "python")] +pub fn bind_stateful_udfs( + expr: ExprRef, + initialized_funcs: &HashMap>, +) -> DaftResult { + expr.transform(|e| match e.as_ref() { + Expr::Function { + func: FunctionExpr::Python(PythonUDF::Stateful(stateful_py_udf)), + inputs, + } => { + let f = initialized_funcs + .get(stateful_py_udf.name.as_ref()) + .ok_or_else(|| { + DaftError::InternalError(format!( + "Unable to find UDF to bind: {}", + stateful_py_udf.name.as_ref() + )) + })?; + let bound_expr = Expr::Function { + func: FunctionExpr::Python(PythonUDF::Stateful(StatefulPythonUDF { + runtime_binding: udf_runtime_binding::UDFRuntimeBinding::Bound(f.clone()), + ..stateful_py_udf.clone() + })), + inputs: inputs.clone(), + }; + Ok(common_treenode::Transformed::yes(bound_expr.into())) + } + _ => Ok(common_treenode::Transformed::no(e)), + }) + .map(|transformed| transformed.data) +} + +/// Helper function that extracts all PartialStatefulUDF python objects from a given expression tree +#[cfg(feature = "python")] +pub fn extract_partial_stateful_udf_py(expr: ExprRef) -> HashMap> { + let mut py_partial_udfs = HashMap::new(); + expr.apply(|child| { + if let Expr::Function { + func: + FunctionExpr::Python(PythonUDF::Stateful(StatefulPythonUDF { + name, + stateful_partial_func: py_partial_udf, + .. + })), + .. + } = child.as_ref() + { + py_partial_udfs.insert(name.as_ref().to_string(), py_partial_udf.0.clone()); + } + Ok(TreeNodeRecursion::Continue) + }) + .unwrap(); + py_partial_udfs +} diff --git a/src/daft-dsl/src/functions/python/udf.rs b/src/daft-dsl/src/functions/python/udf.rs index 282083bab7..edaae70889 100644 --- a/src/daft-dsl/src/functions/python/udf.rs +++ b/src/daft-dsl/src/functions/python/udf.rs @@ -160,6 +160,7 @@ impl FunctionEvaluator for StatefulPythonUDF { #[cfg(feature = "python")] fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { + use crate::functions::python::udf_runtime_binding::UDFRuntimeBinding; use pyo3::{ types::{PyDict, PyTuple}, Python, @@ -173,51 +174,68 @@ impl FunctionEvaluator for StatefulPythonUDF { ))); } - Python::with_gil(|py| { - // Extract the required Python objects to call our run_udf helper - let func = self - .stateful_partial_func - .0 - .getattr(py, pyo3::intern!(py, "func_cls"))?; - let bound_args = self - .stateful_partial_func - .0 - .getattr(py, pyo3::intern!(py, "bound_args"))?; - - // HACK: This is the naive initialization of the class. It is performed once-per-evaluate which is not ideal. - // Ideally we need to allow evaluate to somehow take in the **initialized** Python class that is provided by the Actor. - // Either that, or the code-path to evaluate a StatefulUDF should bypass `evaluate` entirely and do its own thing. - let func = match &self.init_args { - None => func.call0(py)?, - Some(init_args) => { - let init_args = init_args - .0 - .as_ref(py) - .downcast::() - .expect("init_args should be a Python tuple"); - let (args, kwargs) = ( - init_args - .get_item(0)? + if let UDFRuntimeBinding::Bound(func) = &self.runtime_binding { + Python::with_gil(|py| { + // Extract the required Python objects to call our run_udf helper + let bound_args = self + .stateful_partial_func + .0 + .getattr(py, pyo3::intern!(py, "bound_args"))?; + run_udf( + py, + inputs, + pyo3::Py::clone_ref(func, py), + bound_args, + &self.return_dtype, + self.batch_size, + ) + }) + } else { + // NOTE: This branch of evaluation performs a naive initialization of the class. It is performed once-per-evaluate + // which is not ideal. Callers trying to .evaluate a StatefulPythonUDF should first bind it to initialized classes. + Python::with_gil(|py| { + // Extract the required Python objects to call our run_udf helper + let func = self + .stateful_partial_func + .0 + .getattr(py, pyo3::intern!(py, "func_cls"))?; + let bound_args = self + .stateful_partial_func + .0 + .getattr(py, pyo3::intern!(py, "bound_args"))?; + + let func = match &self.init_args { + None => func.call0(py)?, + Some(init_args) => { + let init_args = init_args + .0 + .as_ref(py) .downcast::() - .expect("init_args[0] should be a tuple of *args"), - init_args - .get_item(1)? - .downcast::() - .expect("init_args[1] should be a dict of **kwargs"), - ); - func.call(py, args, Some(kwargs))? - } - }; - - run_udf( - py, - inputs, - func, - bound_args, - &self.return_dtype, - self.batch_size, - ) - }) + .expect("init_args should be a Python tuple"); + let (args, kwargs) = ( + init_args + .get_item(0)? + .downcast::() + .expect("init_args[0] should be a tuple of *args"), + init_args + .get_item(1)? + .downcast::() + .expect("init_args[1] should be a dict of **kwargs"), + ); + func.call(py, args, Some(kwargs))? + } + }; + + run_udf( + py, + inputs, + func, + bound_args, + &self.return_dtype, + self.batch_size, + ) + }) + } } #[cfg(not(feature = "python"))] diff --git a/src/daft-dsl/src/functions/python/udf_runtime_binding.rs b/src/daft-dsl/src/functions/python/udf_runtime_binding.rs new file mode 100644 index 0000000000..6ef81faddd --- /dev/null +++ b/src/daft-dsl/src/functions/python/udf_runtime_binding.rs @@ -0,0 +1,70 @@ +use std::hash::{Hash, Hasher}; + +use serde::{de::Visitor, Deserialize, Serialize}; + +/// A binding between the StatefulPythonUDF and an initialized Python callable +/// +/// This is `Unbound` during planning, and bound to an initialized Python callable +/// by an Actor right before execution. +/// +/// Note that attempting to Hash, Eq, Serde this when it is bound will panic! +#[derive(Debug, Clone)] +pub enum UDFRuntimeBinding { + Unbound, + Bound(pyo3::PyObject), +} + +impl PartialEq for UDFRuntimeBinding { + fn eq(&self, other: &Self) -> bool { + matches!((self, other), (Self::Unbound, Self::Unbound)) + } +} + +impl Eq for UDFRuntimeBinding {} + +impl Hash for UDFRuntimeBinding { + fn hash(&self, state: &mut H) { + match self { + Self::Unbound => state.write_u8(0), + Self::Bound(_) => panic!("Cannot hash a bound UDFRuntimeBinding."), + } + } +} + +impl Serialize for UDFRuntimeBinding { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + match self { + Self::Unbound => serializer.serialize_unit(), + Self::Bound(_) => panic!("Cannot serialize a bound UDFRuntimeBinding."), + } + } +} + +struct UDFRuntimeBindingVisitor; + +impl<'de> Visitor<'de> for UDFRuntimeBindingVisitor { + type Value = UDFRuntimeBinding; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("unit") + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(UDFRuntimeBinding::Unbound) + } +} + +impl<'de> Deserialize<'de> for UDFRuntimeBinding { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_unit(UDFRuntimeBindingVisitor) + } +} diff --git a/src/daft-dsl/src/lib.rs b/src/daft-dsl/src/lib.rs index 557cf3a160..4ac9c422e9 100644 --- a/src/daft-dsl/src/lib.rs +++ b/src/daft-dsl/src/lib.rs @@ -38,6 +38,8 @@ pub fn register_modules(_py: Python, parent: &PyModule) -> PyResult<()> { parent.add_wrapped(wrap_pyfunction!(python::series_lit))?; parent.add_wrapped(wrap_pyfunction!(python::stateless_udf))?; parent.add_wrapped(wrap_pyfunction!(python::stateful_udf))?; + parent.add_wrapped(wrap_pyfunction!(python::extract_partial_stateful_udf_py))?; + parent.add_wrapped(wrap_pyfunction!(python::bind_stateful_udfs))?; parent.add_wrapped(wrap_pyfunction!(python::eq))?; parent.add_wrapped(wrap_pyfunction!(python::resolve_expr))?; diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index 4217ed68c0..032fa10dfc 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -1,5 +1,6 @@ use std::collections::hash_map::DefaultHasher; +use std::collections::HashMap; use std::hash::{Hash, Hasher}; use std::sync::Arc; @@ -231,6 +232,23 @@ pub fn stateful_udf( }) } +/// Extracts the `class PartialStatefulUDF` Python objects that are in the specified expression tree +#[pyfunction] +pub fn extract_partial_stateful_udf_py(expr: PyExpr) -> HashMap> { + use crate::functions::python::extract_partial_stateful_udf_py; + extract_partial_stateful_udf_py(expr.expr) +} + +/// Binds the StatefulPythonUDFs in a given expression to any corresponding initialized Python callables in the provided map +#[pyfunction] +pub fn bind_stateful_udfs( + expr: PyExpr, + initialized_funcs: HashMap>, +) -> PyResult { + use crate::functions::python::bind_stateful_udfs; + Ok(bind_stateful_udfs(expr.expr, &initialized_funcs).map(PyExpr::from)?) +} + #[pyclass(module = "daft.daft")] #[derive(Debug, Clone, Serialize, Deserialize)] pub struct PyExpr { diff --git a/src/daft-scheduler/src/scheduler.rs b/src/daft-scheduler/src/scheduler.rs index fd809c97a6..5b8c9b13a3 100644 --- a/src/daft-scheduler/src/scheduler.rs +++ b/src/daft-scheduler/src/scheduler.rs @@ -317,35 +317,6 @@ fn physical_plan_to_partition_tasks( input, projection, .. }, ) => { - use daft_dsl::{ - common_treenode::TreeNode, - functions::{ - python::{PythonUDF, StatefulPythonUDF}, - FunctionExpr, - }, - }; - - // Extract any StatefulUDFs from the projection - let mut py_partial_udfs = HashMap::new(); - projection.iter().for_each(|e| { - e.apply(|child| { - if let Expr::Function { - func: - FunctionExpr::Python(PythonUDF::Stateful(StatefulPythonUDF { - name, - stateful_partial_func: py_partial_udf, - .. - })), - .. - } = child.as_ref() - { - py_partial_udfs.insert(name.as_ref().to_string(), py_partial_udf.0.clone()); - } - Ok(daft_dsl::common_treenode::TreeNodeRecursion::Continue) - }) - .unwrap(); - }); - let upstream_iter = physical_plan_to_partition_tasks(input, py, psets)?; let py_iter = py .import(pyo3::intern!(py, "daft.execution.rust_physical_plan_shim"))? @@ -356,8 +327,8 @@ fn physical_plan_to_partition_tasks( .iter() .map(|expr| PyExpr::from(expr.clone())) .collect::>(), - py_partial_udfs, app.resource_request(), + app.concurrency(), ))?; Ok(py_iter.into()) } diff --git a/tests/actor_pool/__init__.py b/tests/actor_pool/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/actor_pool/test_pyactor_pool.py b/tests/actor_pool/test_pyactor_pool.py new file mode 100644 index 0000000000..0f410549c6 --- /dev/null +++ b/tests/actor_pool/test_pyactor_pool.py @@ -0,0 +1,74 @@ +import multiprocessing +from concurrent.futures import wait + +import pytest + +import daft +from daft import DataType, ResourceRequest +from daft.context import get_context +from daft.execution.execution_step import StatefulUDFProject +from daft.expressions import ExpressionsProjection +from daft.runners.partitioning import PartialPartitionMetadata +from daft.runners.pyrunner import PyActorPool, PyRunner +from daft.table import MicroPartition + + +@daft.udf(return_dtype=DataType.int64()) +class MyStatefulUDF: + def __init__(self): + self.state = 0 + + def __call__(self, x): + self.state += 1 + return [i + self.state for i in x.to_pylist()] + + +def test_pyactor_pool(): + projection = ExpressionsProjection([MyStatefulUDF(daft.col("x"))]) + pool = PyActorPool("my-pool", 1, ResourceRequest(num_cpus=1), projection) + initial_partition = MicroPartition.from_pydict({"x": [1, 1, 1]}) + ppm = PartialPartitionMetadata(num_rows=None, size_bytes=None) + instr = StatefulUDFProject(projection=projection) + + pool_id = pool.setup() + assert pool_id == "my-pool" + + result = pool.submit( + instruction_stack=[instr], + partitions=[initial_partition], + final_metadata=[ppm], + ) + done, _ = wait([result], timeout=None) + result_data = list(done)[0].result()[0] + assert result_data.partition().to_pydict() == {"x": [2, 2, 2]} + + result = pool.submit( + instruction_stack=[instr], + partitions=[initial_partition], + final_metadata=[ppm], + ) + done, _ = wait([result], timeout=None) + result_data = list(done)[0].result()[0] + assert result_data.partition().to_pydict() == {"x": [3, 3, 3]} + + result = pool.submit( + instruction_stack=[instr], + partitions=[initial_partition], + final_metadata=[ppm], + ) + done, _ = wait([result], timeout=None) + result_data = list(done)[0].result()[0] + assert result_data.partition().to_pydict() == {"x": [4, 4, 4]} + + +@pytest.mark.skipif(get_context().runner_config.name != "py", reason="Test can only be run on PyRunner") +def test_pyactor_pool_not_enough_resources(): + cpu_count = multiprocessing.cpu_count() + projection = ExpressionsProjection([MyStatefulUDF(daft.col("x"))]) + + runner = get_context().runner() + assert isinstance(runner, PyRunner) + + with pytest.raises(RuntimeError, match=f"Requested {float(cpu_count + 1)} CPUs but found only"): + with runner.actor_pool_context("my-pool", ResourceRequest(num_cpus=1), cpu_count + 1, projection) as _: + pass