diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi
index dd77d8148e..a5e4ad0844 100644
--- a/daft/daft/__init__.pyi
+++ b/daft/daft/__init__.pyi
@@ -9,7 +9,7 @@ from daft.io.scan import ScanOperator
 from daft.plan_scheduler.physical_plan_scheduler import PartitionT
 from daft.runners.partitioning import PartitionCacheEntry
 from daft.sql.sql_connection import SQLConnection
-from daft.udf import InitArgsType, PartialStatefulUDF, PartialStatelessUDF
+from daft.udf import BoundUDFArgs, InitArgsType, UninitializedUdf
 
 if TYPE_CHECKING:
     import pyarrow as pa
@@ -1123,29 +1123,20 @@ def interval_lit(
 ) -> PyExpr: ...
 def decimal_lit(sign: bool, digits: tuple[int, ...], exp: int) -> PyExpr: ...
 def series_lit(item: PySeries) -> PyExpr: ...
-def stateless_udf(
+def udf(
     name: str,
-    partial_stateless_udf: PartialStatelessUDF,
+    inner: UninitializedUdf,
+    bound_args: BoundUDFArgs,
     expressions: list[PyExpr],
     return_dtype: PyDataType,
-    resource_request: ResourceRequest | None,
-    batch_size: int | None,
-) -> PyExpr: ...
-def stateful_udf(
-    name: str,
-    partial_stateful_udf: PartialStatefulUDF,
-    expressions: list[PyExpr],
-    return_dtype: PyDataType,
-    resource_request: ResourceRequest | None,
     init_args: InitArgsType,
+    resource_request: ResourceRequest | None,
     batch_size: int | None,
     concurrency: int | None,
 ) -> PyExpr: ...
 def check_column_name_validity(name: str, schema: PySchema): ...
-def extract_partial_stateful_udf_py(
-    expression: PyExpr,
-) -> dict[str, tuple[PartialStatefulUDF, InitArgsType]]: ...
-def bind_stateful_udfs(expression: PyExpr, initialized_funcs: dict[str, Callable]) -> PyExpr: ...
+def initialize_udfs(expression: PyExpr) -> PyExpr: ...
+def get_udf_names(expression: PyExpr) -> list[str]: ...
 def resolve_expr(expr: PyExpr, schema: PySchema) -> tuple[PyExpr, PyField]: ...
 def hash(expr: PyExpr, seed: Any | None = None) -> PyExpr: ...
 def cosine_distance(expr: PyExpr, other: PyExpr) -> PyExpr: ...
@@ -1885,12 +1876,9 @@ class PyDaftPlanningConfig:
     def with_config_values(
         self,
         default_io_config: IOConfig | None = None,
-        enable_actor_pool_projections: bool | None = None,
     ) -> PyDaftPlanningConfig: ...
     @property
     def default_io_config(self) -> IOConfig: ...
-    @property
-    def enable_actor_pool_projections(self) -> bool: ...
 
 def build_type() -> str: ...
 def version() -> str: ...
diff --git a/daft/execution/actor_pool_udf.py b/daft/execution/actor_pool_udf.py
new file mode 100644
index 0000000000..a60e2da7ed
--- /dev/null
+++ b/daft/execution/actor_pool_udf.py
@@ -0,0 +1,53 @@
+from __future__ import annotations
+
+import logging
+import multiprocessing as mp
+from typing import TYPE_CHECKING
+
+from daft.expressions import Expression, ExpressionsProjection
+from daft.table import MicroPartition
+
+if TYPE_CHECKING:
+    from multiprocessing.connection import Connection
+
+    from daft.daft import PyExpr, PyMicroPartition
+
+logger = logging.getLogger(__name__)
+
+
+def actor_event_loop(uninitialized_projection: ExpressionsProjection, conn: Connection) -> None:
+    """
+    Event loop that runs in an actor process and receives MicroPartitions to evaluate with an initialized UDF projection.
+
+    Terminates once it receives None.
+ """ + initialized_projection = ExpressionsProjection([e._initialize_udfs() for e in uninitialized_projection]) + + while True: + input: MicroPartition | None = conn.recv() + if input is None: + break + + output = input.eval_expression_list(initialized_projection) + conn.send(output) + + +class ActorHandle: + """Handle class for initializing, interacting with, and tearing down a single local actor process.""" + + def __init__(self, projection: list[PyExpr]) -> None: + self.handle_conn, actor_conn = mp.Pipe() + + expr_projection = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in projection]) + self.actor_process = mp.Process(target=actor_event_loop, args=(expr_projection, actor_conn)) + self.actor_process.start() + + def eval_input(self, input: PyMicroPartition) -> PyMicroPartition: + self.handle_conn.send(MicroPartition._from_pymicropartition(input)) + output: MicroPartition = self.handle_conn.recv() + return output._micropartition + + def teardown(self) -> None: + self.handle_conn.send(None) + self.handle_conn.close() + self.actor_process.join() diff --git a/daft/execution/execution_step.py b/daft/execution/execution_step.py index 94873a4bb4..a1f6dea909 100644 --- a/daft/execution/execution_step.py +++ b/daft/execution/execution_step.py @@ -553,7 +553,7 @@ def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) @dataclass(frozen=True) -class StatefulUDFProject(SingleOutputInstruction): +class ActorPoolProject(SingleOutputInstruction): projection: ExpressionsProjection def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]: @@ -564,7 +564,7 @@ def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) PartialPartitionMetadata( num_rows=None, # UDFs can potentially change cardinality size_bytes=None, - boundaries=None, # TODO: figure out if the stateful UDF projection changes boundaries + boundaries=None, # TODO: figure out if the actor pool UDF projection changes boundaries ) ] diff --git a/daft/execution/physical_plan.py b/daft/execution/physical_plan.py index 55fa3f1f03..0723665e96 100644 --- a/daft/execution/physical_plan.py +++ b/daft/execution/physical_plan.py @@ -229,7 +229,7 @@ def actor_pool_context( name: Name of the actor pool for debugging/observability resource_request: Requested amount of resources for each actor num_actors: Number of actors to spin up - projection: Projection to be run on the incoming data (contains Stateful UDFs as well as other stateless expressions such as aliases) + projection: Projection to be run on the incoming data (contains actor pool UDFs as well as other stateless expressions such as aliases) """ ... 
@@ -243,12 +243,10 @@ def actor_pool_project( ) -> InProgressPhysicalPlan[PartitionT]: stage_id = next(stage_id_counter) - from daft.daft import extract_partial_stateful_udf_py + from daft.daft import get_udf_names - stateful_udf_names = "-".join( - name for expr in projection for name in extract_partial_stateful_udf_py(expr._expr).keys() - ) - actor_pool_name = f"{stateful_udf_names}-stage={stage_id}" + udf_names = "-".join(name for expr in projection for name in get_udf_names(expr._expr)) + actor_pool_name = f"{udf_names}-stage={stage_id}" # Keep track of materializations of the children tasks child_materializations: deque[SingleOutputPartitionTask[PartitionT]] = deque() @@ -285,7 +283,7 @@ def actor_pool_project( actor_pool_id=actor_pool_id, ) .add_instruction( - instruction=execution_step.StatefulUDFProject(projection), + instruction=execution_step.ActorPoolProject(projection), resource_request=task_resource_request, ) .finalize_partition_task_single_output( diff --git a/daft/execution/stateful_actor.py b/daft/execution/stateful_actor.py deleted file mode 100644 index 8aad591550..0000000000 --- a/daft/execution/stateful_actor.py +++ /dev/null @@ -1,79 +0,0 @@ -from __future__ import annotations - -import logging -import multiprocessing as mp -from typing import TYPE_CHECKING - -from daft.expressions import Expression, ExpressionsProjection -from daft.table import MicroPartition - -if TYPE_CHECKING: - from multiprocessing.connection import Connection - - from daft.daft import PyExpr, PyMicroPartition - -logger = logging.getLogger(__name__) - - -def initialize_actor_pool_projection(projection: ExpressionsProjection) -> ExpressionsProjection: - """Initializes the stateful UDFs in the projection.""" - - from daft.daft import extract_partial_stateful_udf_py - - partial_stateful_udfs = { - name: psu for expr in projection for name, psu in extract_partial_stateful_udf_py(expr._expr).items() - } - - logger.info("Initializing stateful UDFs: %s", ", ".join(partial_stateful_udfs.keys())) - - initialized_stateful_udfs = {} - for name, (partial_udf, init_args) in partial_stateful_udfs.items(): - if init_args is None: - initialized_stateful_udfs[name] = partial_udf.func_cls() - else: - args, kwargs = init_args - initialized_stateful_udfs[name] = partial_udf.func_cls(*args, **kwargs) - - initialized_projection = ExpressionsProjection( - [e._bind_stateful_udfs(initialized_stateful_udfs) for e in projection] - ) - - return initialized_projection - - -def stateful_actor_event_loop(uninitialized_projection: ExpressionsProjection, conn: Connection) -> None: - """ - Event loop that runs in a stateful actor process and receives MicroPartitions to evaluate with a stateful UDF. - - Terminates once it receives None. 
- """ - initialized_projection = initialize_actor_pool_projection(uninitialized_projection) - - while True: - input: MicroPartition | None = conn.recv() - if input is None: - break - - output = input.eval_expression_list(initialized_projection) - conn.send(output) - - -class StatefulActorHandle: - """Handle class for initializing, interacting with, and tearing down a single local stateful actor process.""" - - def __init__(self, projection: list[PyExpr]) -> None: - self.handle_conn, actor_conn = mp.Pipe() - - expr_projection = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in projection]) - self.actor_process = mp.Process(target=stateful_actor_event_loop, args=(expr_projection, actor_conn)) - self.actor_process.start() - - def eval_input(self, input: PyMicroPartition) -> PyMicroPartition: - self.handle_conn.send(MicroPartition._from_pymicropartition(input)) - output: MicroPartition = self.handle_conn.recv() - return output._micropartition - - def teardown(self) -> None: - self.handle_conn.send(None) - self.handle_conn.close() - self.actor_process.join() diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index ab125958ab..baa698aaf4 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -19,7 +19,7 @@ import daft.daft as native from daft import context -from daft.daft import CountMode, ImageFormat, ImageMode, ResourceRequest, bind_stateful_udfs +from daft.daft import CountMode, ImageFormat, ImageMode, ResourceRequest, initialize_udfs from daft.daft import PyExpr as _PyExpr from daft.daft import col as _col from daft.daft import date_lit as _date_lit @@ -28,13 +28,12 @@ from daft.daft import list_sort as _list_sort from daft.daft import lit as _lit from daft.daft import series_lit as _series_lit -from daft.daft import stateful_udf as _stateful_udf -from daft.daft import stateless_udf as _stateless_udf from daft.daft import time_lit as _time_lit from daft.daft import timestamp_lit as _timestamp_lit from daft.daft import to_struct as _to_struct from daft.daft import tokenize_decode as _tokenize_decode from daft.daft import tokenize_encode as _tokenize_encode +from daft.daft import udf as _udf from daft.daft import url_download as _url_download from daft.daft import utf8_count_matches as _utf8_count_matches from daft.datatype import DataType, TimeUnit @@ -45,7 +44,7 @@ if TYPE_CHECKING: from daft.io import IOConfig - from daft.udf import PartialStatefulUDF, PartialStatelessUDF + from daft.udf import BoundUDFArgs, InitArgsType, UninitializedUdf # This allows Sphinx to correctly work against our "namespaced" accessor functions by overriding @property to # return a class instance of the namespace instead of a property object. 
elif os.getenv("DAFT_SPHINX_BUILD") == "1": @@ -260,39 +259,26 @@ def _to_expression(obj: object) -> Expression: return lit(obj) @staticmethod - def stateless_udf( + def udf( name: builtins.str, - partial: PartialStatelessUDF, + inner: UninitializedUdf, + bound_args: BoundUDFArgs, expressions: builtins.list[Expression], return_dtype: DataType, + init_args: InitArgsType, resource_request: ResourceRequest | None, batch_size: int | None, - ) -> Expression: - return Expression._from_pyexpr( - _stateless_udf( - name, partial, [e._expr for e in expressions], return_dtype._dtype, resource_request, batch_size - ) - ) - - @staticmethod - def stateful_udf( - name: builtins.str, - partial: PartialStatefulUDF, - expressions: builtins.list[Expression], - return_dtype: DataType, - resource_request: ResourceRequest | None, - init_args: tuple[tuple[Any, ...], dict[builtins.str, Any]] | None, - batch_size: int | None, concurrency: int | None, ) -> Expression: return Expression._from_pyexpr( - _stateful_udf( + _udf( name, - partial, + inner, + bound_args, [e._expr for e in expressions], return_dtype._dtype, - resource_request, init_args, + resource_request, batch_size, concurrency, ) @@ -1018,7 +1004,7 @@ def apply(self, func: Callable, return_dtype: DataType) -> Expression: Returns: Expression: New expression after having run the function on the expression """ - from daft.udf import CommonUDFArgs, StatelessUDF + from daft.udf import UDF def batch_func(self_series): return [func(x) for x in self_series.to_pylist()] @@ -1028,14 +1014,10 @@ def batch_func(self_series): name = name + "." name = name + getattr(func, "__qualname__") # type: ignore[call-overload] - return StatelessUDF( + return UDF( + inner=batch_func, name=name, - func=batch_func, return_dtype=return_dtype, - common_args=CommonUDFArgs( - resource_request=None, - batch_size=None, - ), )(self) def is_null(self) -> Expression: @@ -1263,8 +1245,8 @@ def __reduce__(self) -> tuple: def _input_mapping(self) -> builtins.str | None: return self._expr._input_mapping() - def _bind_stateful_udfs(self, initialized_funcs: dict[builtins.str, Callable]) -> Expression: - return Expression._from_pyexpr(bind_stateful_udfs(self._expr, initialized_funcs)) + def _initialize_udfs(self) -> Expression: + return Expression._from_pyexpr(initialize_udfs(self._expr)) SomeExpressionNamespace = TypeVar("SomeExpressionNamespace", bound="ExpressionNamespace") diff --git a/daft/runners/pyrunner.py b/daft/runners/pyrunner.py index 2d4d20ff3f..e386d0ea99 100644 --- a/daft/runners/pyrunner.py +++ b/daft/runners/pyrunner.py @@ -14,6 +14,7 @@ from daft.daft import FileFormatConfig, FileInfos, IOConfig, ResourceRequest, SystemInfo from daft.execution.native_executor import NativeExecutor from daft.execution.physical_plan import ActorPoolManager +from daft.expressions import ExpressionsProjection from daft.filesystem import glob_path_with_stats from daft.internal.gpu import cuda_visible_devices from daft.runners import runner_io @@ -34,7 +35,6 @@ if TYPE_CHECKING: from daft.execution import physical_plan from daft.execution.execution_step import Instruction, PartitionTask - from daft.expressions import ExpressionsProjection from daft.logical.builder import LogicalPlanBuilder logger = logging.getLogger(__name__) @@ -165,11 +165,11 @@ def release(self, resources: AcquiredResources | list[AcquiredResources]): self.available_resources.gpus[gpu] += amount -class PyStatefulActorSingleton: +class PyActorSingleton: """ - This class stores the singleton `initialized_projection` that is isolated 
to each Python process. It stores the projection with initialized stateful UDF objects of a single actor.
+    This class stores the singleton `initialized_projection` that is isolated to each Python process. It stores the projection with initialized actor pool UDF objects of a single actor.
 
-    Currently, only one stateful UDF per actor is supported, but we allow multiple here in case we want to support multiple stateful UDFs in the future.
+    Currently, only one actor pool UDF per actor is supported, but we allow multiple here in case we want to support multiple actor pool UDFs in the future.
 
     Note: The class methods should only be called inside of actor processes.
     """
@@ -181,28 +181,28 @@ def initialize_actor_global_state(
         uninitialized_projection: ExpressionsProjection,
         cuda_device_queue: mp.Queue[str],
     ):
-        if PyStatefulActorSingleton.initialized_projection is not None:
+        if PyActorSingleton.initialized_projection is not None:
             raise RuntimeError("Cannot initialize Python process actor twice.")
 
         import os
 
-        from daft.execution.stateful_actor import initialize_actor_pool_projection
-
         os.environ["CUDA_VISIBLE_DEVICES"] = cuda_device_queue.get(timeout=1)
 
-        PyStatefulActorSingleton.initialized_projection = initialize_actor_pool_projection(uninitialized_projection)
+        PyActorSingleton.initialized_projection = ExpressionsProjection(
+            [e._initialize_udfs() for e in uninitialized_projection]
+        )
 
     @staticmethod
-    def build_partitions_with_stateful_project(
+    def build_partitions_with_actor_pool_project(
         partition: MicroPartition,
         partial_metadata: PartialPartitionMetadata,
     ) -> list[MaterializedResult[MicroPartition]]:
-        # Bind the expressions to the initialized stateful UDFs, which should already have been initialized at process start-up
+        # Evaluate with the actor pool UDF projection, which should already have been initialized at process start-up
         assert (
-            PyStatefulActorSingleton.initialized_projection is not None
-        ), "PyActor process must be initialized with stateful UDFs before execution"
+            PyActorSingleton.initialized_projection is not None
+        ), "PyActor process must be initialized with actor pool UDFs before execution"
 
-        new_part = partition.eval_expression_list(PyStatefulActorSingleton.initialized_projection)
+        new_part = partition.eval_expression_list(PyActorSingleton.initialized_projection)
         return [
             LocalMaterializedResult(
                 new_part, PartitionMetadata.from_table(new_part).merge_with_partial(partial_metadata)
             )
@@ -235,17 +235,17 @@ def submit(
         assert self._executor is not None, "Cannot submit to uninitialized PyActorPool"
 
         # PyActorPools can only handle 1 to 1 projections (no fanouts/fan-ins) and only
-        # StatefulUDFProject instructions (no filters etc)
+        # ActorPoolProject instructions (no filters etc)
         assert len(partitions) == 1
         assert len(final_metadata) == 1
         assert len(instruction_stack) == 1
         instruction = instruction_stack[0]
-        assert isinstance(instruction, execution_step.StatefulUDFProject)
+        assert isinstance(instruction, execution_step.ActorPoolProject)
 
         partition = partitions[0]
         partial_metadata = final_metadata[0]
         return self._executor.submit(
-            PyStatefulActorSingleton.build_partitions_with_stateful_project,
+            PyActorSingleton.build_partitions_with_actor_pool_project,
             partition,
             partial_metadata,
         )
@@ -264,7 +264,7 @@ def setup(self) -> None:
         self._executor = futures.ProcessPoolExecutor(
             self._num_actors,
-            initializer=PyStatefulActorSingleton.initialize_actor_global_state,
+            initializer=PyActorSingleton.initialize_actor_global_state,
             initargs=(self._projection, cuda_device_queue),
        )
diff --git a/daft/runners/ray_runner.py b/daft/runners/ray_runner.py
index 82f648e44b..b852bfa3b0 100644
--- a/daft/runners/ray_runner.py
+++ b/daft/runners/ray_runner.py
@@ -40,10 +40,10 @@
     IOConfig,
     PyDaftExecutionConfig,
     ResourceRequest,
-    extract_partial_stateful_udf_py,
 )
 from daft.datatype import DataType
 from daft.execution.execution_step import (
+    ActorPoolProject,
     FanoutInstruction,
     Instruction,
     MultiOutputPartitionTask,
@@ -51,7 +51,6 @@
     ReduceInstruction,
     ScanWithTask,
     SingleOutputPartitionTask,
-    StatefulUDFProject,
 )
 from daft.execution.physical_plan import ActorPoolManager
 from daft.expressions import ExpressionsProjection
@@ -1062,7 +1061,10 @@ def _build_partitions_on_actor_pool(
     actor_pool: RayRoundRobinActorPool,
 ) -> list[ray.ObjectRef]:
     """Run a PartitionTask on an actor pool and return the resulting list of partitions."""
-    [metadatas_ref, *partitions] = actor_pool.submit(task.instructions, task.partial_metadatas, task.inputs)
+    assert len(task.instructions) == 1, "Actor pool can only handle single ActorPoolProject instructions"
+    assert isinstance(task.instructions[0], ActorPoolProject)
+
+    [metadatas_ref, *partitions] = actor_pool.submit(task.partial_metadatas, task.inputs)
     metadatas_accessor = PartitionMetadataAccessor(metadatas_ref)
     task.set_result(
         [
@@ -1080,26 +1082,19 @@
 @ray.remote
 class DaftRayActor:
     def __init__(self, daft_execution_config: PyDaftExecutionConfig, uninitialized_projection: ExpressionsProjection):
+        from daft.daft import get_udf_names
+
         self.daft_execution_config = daft_execution_config
 
-        partial_stateful_udfs = {
-            name: psu
-            for expr in uninitialized_projection
-            for name, psu in extract_partial_stateful_udf_py(expr._expr).items()
-        }
-        logger.info("Initializing stateful UDFs: %s", ", ".join(partial_stateful_udfs.keys()))
-
-        self.initialized_stateful_udfs = {}
-        for name, (partial_udf, init_args) in partial_stateful_udfs.items():
-            if init_args is None:
-                self.initialized_stateful_udfs[name] = partial_udf.func_cls()
-            else:
-                args, kwargs = init_args
-                self.initialized_stateful_udfs[name] = partial_udf.func_cls(*args, **kwargs)
+
+        logger.info(
+            "Initializing actor pool UDFs: %s",
+            ", ".join(name for expr in uninitialized_projection for name in get_udf_names(expr._expr)),
+        )
+        self.initialized_projection = ExpressionsProjection([e._initialize_udfs() for e in uninitialized_projection])
 
     @ray.method(num_returns=2)
     def run(
         self,
-        uninitialized_projection: ExpressionsProjection,
         partial_metadatas: list[PartitionMetadata],
         *inputs: MicroPartition,
     ) -> list[list[PartitionMetadata] | MicroPartition]:
@@ -1109,11 +1104,7 @@
         part = inputs[0]
         partial = partial_metadatas[0]
 
-        # Bind the ExpressionsProjection to the initialized UDFs
-        initialized_projection = ExpressionsProjection(
-            [e._bind_stateful_udfs(self.initialized_stateful_udfs) for e in uninitialized_projection]
-        )
-        new_part = part.eval_expression_list(initialized_projection)
+        new_part = part.eval_expression_list(self.initialized_projection)
 
         return [
             [PartitionMetadata.from_table(new_part).merge_with_partial(partial)],
@@ -1159,24 +1150,15 @@ def teardown(self):
         self._actors = None
         del old_actors
 
-    def submit(
-        self, instruction_stack: list[Instruction], partial_metadatas: list[ray.ObjectRef], inputs: list[ray.ObjectRef]
-    ) -> list[ray.ObjectRef]:
+    def submit(self, partial_metadatas: list[ray.ObjectRef], inputs: list[ray.ObjectRef]) -> list[ray.ObjectRef]:
         assert self._actors is not None, "Must have active Ray actors during submission"
 
-        assert (
-            len(instruction_stack) == 1
-        ), "RayRoundRobinActorPool can only handle single StatefulUDFProject instructions"
-        instruction = instruction_stack[0]
-        assert isinstance(instruction, StatefulUDFProject)
-        projection = instruction.projection
-
         # Determine which actor to schedule on in a round-robin fashion
         idx = self._task_idx % self._num_actors
         self._task_idx += 1
 
         actor = self._actors[idx]
-        return actor.run.remote(projection, partial_metadatas, *inputs)
+        return actor.run.remote(partial_metadatas, *inputs)
 
 
 class RayRunner(Runner[ray.ObjectRef]):
diff --git a/daft/udf.py b/daft/udf.py
index fdc238b980..beecbc7f45 100644
--- a/daft/udf.py
+++ b/daft/udf.py
@@ -5,7 +5,6 @@
 import inspect
 from typing import Any, Callable, Dict, Optional, Tuple, Union
 
-from daft.context import get_context
 from daft.daft import PyDataType, ResourceRequest
 from daft.datatype import DataType
 from daft.dependencies import np, pa
@@ -13,7 +12,21 @@
 from daft.series import PySeries, Series
 
 InitArgsType = Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]]
-UserProvidedPythonFunction = Callable[..., Union[Series, "np.ndarray", list]]
+UdfReturnType = Union[Series, list, "np.ndarray", "pa.Array", "pa.ChunkedArray"]
+UserDefinedPyFunc = Callable[..., UdfReturnType]
+UserDefinedPyFuncLike = Union[UserDefinedPyFunc, type]
+
+
+@dataclasses.dataclass(frozen=True)
+class UninitializedUdf:
+    inner: Callable[..., UserDefinedPyFunc]
+
+    def initialize(self, init_args: InitArgsType) -> UserDefinedPyFunc:
+        if init_args is None:
+            return self.inner()
+        else:
+            args, kwargs = init_args
+            return self.inner(*args, **kwargs)
 
 
 @dataclasses.dataclass(frozen=True)
@@ -64,13 +77,13 @@ def __hash__(self) -> int:
 
 # Assumes there is at least one evaluated expression
 def run_udf(
-    func: Callable,
+    func: UserDefinedPyFunc,
     bound_args: BoundUDFArgs,
     evaluated_expressions: list[Series],
     py_return_dtype: PyDataType,
     batch_size: int | None,
 ) -> PySeries:
-    """API to call from Rust code that will call an UDF (initialized, in the case of stateful UDFs) on the inputs"""
+    """API to call from Rust code that will call a UDF (initialized, in the case of actor pool UDFs) on the inputs"""
     return_dtype = DataType._from_pydatatype(py_return_dtype)
     kwarg_keys = list(bound_args.bound_args.kwargs.keys())
     arg_keys = bound_args.arg_keys()
@@ -146,10 +159,10 @@ def get_args_for_slice(start: int, end: int):
     # Post-processing of results into a Series of the appropriate dtype
     if isinstance(results[0], Series):
-        result_series = Series.concat(results)
+        result_series = Series.concat(results)  # type: ignore
         return result_series.rename(name).cast(return_dtype)._series
     elif isinstance(results[0], list):
-        result_list = [x for res in results for x in res]
+        result_list = [x for res in results for x in res]  # type: ignore
         if return_dtype == DataType.python():
             return Series.from_pylist(result_list, name=name, pyobj="force")._series
         else:
@@ -169,134 +182,73 @@ def get_args_for_slice(start: int, end: int):
 
 
 @dataclasses.dataclass
-class CommonUDFArgs:
-    resource_request: ResourceRequest | None
-    batch_size: int | None
-
-    def override_options(
-        self,
-        *,
-        num_cpus: float | None = _UnsetMarker,
-        num_gpus: float | None = _UnsetMarker,
-        memory_bytes: int | None = _UnsetMarker,
-        batch_size: int | None = _UnsetMarker,
-    ) -> CommonUDFArgs:
-        result = self
-
-        # Any changes to resource request
-        if not all(
-            (
-                num_cpus is _UnsetMarker,
-                num_gpus is _UnsetMarker,
-                memory_bytes is _UnsetMarker,
-            )
-        ):
-            new_resource_request =
ResourceRequest() if self.resource_request is None else self.resource_request - if num_cpus is not _UnsetMarker: - new_resource_request = new_resource_request.with_num_cpus(num_cpus) - if num_gpus is not _UnsetMarker: - new_resource_request = new_resource_request.with_num_gpus(num_gpus) - if memory_bytes is not _UnsetMarker: - new_resource_request = new_resource_request.with_memory_bytes(memory_bytes) - result = dataclasses.replace(result, resource_request=new_resource_request) - - if batch_size is not _UnsetMarker: - result.batch_size = batch_size - - return result - - -@dataclasses.dataclass -class PartialStatelessUDF: - """Partially bound stateless UDF""" - - func: UserProvidedPythonFunction - return_dtype: DataType - bound_args: BoundUDFArgs +class UDF: + """A class produced by applying the `@daft.udf` decorator over a Python function or class. + Calling this class produces a `daft.Expression` that can be used in a DataFrame function. -@dataclasses.dataclass -class PartialStatefulUDF: - """Partially bound stateful UDF""" - - func_cls: Callable[[], UserProvidedPythonFunction] - return_dtype: DataType - bound_args: BoundUDFArgs - - -@dataclasses.dataclass -class StatelessUDF: - """A Stateless UDF is produced by calling `@udf` over a Python function""" + Example: + >>> import daft + >>> @daft.udf(return_dtype=daft.DataType.float64()) + ... def multiply_and_add(x: daft.Series, y: float, z: float): + ... return x.to_arrow().to_numpy() * y + z + >>> + >>> df = daft.from_pydict({"x": [1, 2, 3]}) + >>> df = df.with_column("result", multiply_and_add(df["x"], 2.0, z=1.5)) + >>> df.show() + ╭───────┬─────────╮ + │ x ┆ result │ + │ --- ┆ --- │ + │ Int64 ┆ Float64 │ + ╞═══════╪═════════╡ + │ 1 ┆ 3.5 │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤ + │ 2 ┆ 5.5 │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤ + │ 3 ┆ 7.5 │ + ╰───────┴─────────╯ + + (Showing first 3 of 3 rows) + """ - common_args: CommonUDFArgs + inner: UserDefinedPyFuncLike name: str - func: UserProvidedPythonFunction return_dtype: DataType + init_args: InitArgsType = None + concurrency: int | None = None + resource_request: ResourceRequest | None = None + batch_size: int | None = None def __post_init__(self): - """Analogous to the @functools.wraps(self.func) pattern - - This will swap out identifiers on `self` to match `self.func`. Most notably, this swaps out - self.__module__ and self.__qualname__, which is used in `__reduce__` during serialization. - """ - functools.update_wrapper(self, self.func) + # Analogous to the @functools.wraps(self.inner) pattern + # This will swap out identifiers on `self` to match `self.inner`. Most notably, this swaps out + # self.__module__ and self.__qualname__, which is used in `__reduce__` during serialization. + functools.update_wrapper(self, self.inner) + + # construct the UninitializedUdf here so that the constructed expressions can maintain equality + if isinstance(self.inner, type): + self.wrapped_inner = UninitializedUdf(self.inner) + else: + self.wrapped_inner = UninitializedUdf(lambda: self.inner) def __call__(self, *args, **kwargs) -> Expression: - """Call the UDF using some input Expressions, producing a new Expression that can be used by a DataFrame. - Args: - *args: Positional arguments to be passed to the UDF. These can be either Expressions or Python values. - **kwargs: Keyword arguments to be passed to the UDF. These can be either Expressions or Python values. + self._validate_init_args() - Returns: - Expression: A new Expression representing the UDF call, which can be used in DataFrame operations. - - .. 
NOTE:: - When passing arguments to the UDF, you can use a mix of Expressions (e.g., df["column"]) and Python values. - Expressions will be evaluated for each row, while Python values will be passed as-is to the UDF. - - Example: - >>> import daft - >>> @daft.udf(return_dtype=daft.DataType.float64()) - ... def multiply_and_add(x: daft.Series, y: float, z: float): - ... return x.to_arrow().to_numpy() * y + z - >>> - >>> df = daft.from_pydict({"x": [1, 2, 3]}) - >>> df = df.with_column("result", multiply_and_add(df["x"], 2.0, z=1.5)) - >>> df.show() - ╭───────┬─────────╮ - │ x ┆ result │ - │ --- ┆ --- │ - │ Int64 ┆ Float64 │ - ╞═══════╪═════════╡ - │ 1 ┆ 3.5 │ - ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤ - │ 2 ┆ 5.5 │ - ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤ - │ 3 ┆ 7.5 │ - ╰───────┴─────────╯ - - (Showing first 3 of 3 rows) - """ - bound_args = BoundUDFArgs(self._bind_func(*args, **kwargs)) + bound_args = self._bind_args(*args, **kwargs) expressions = list(bound_args.expressions().values()) - return Expression.stateless_udf( + + return Expression.udf( name=self.name, - partial=PartialStatelessUDF(self.func, self.return_dtype, bound_args), + inner=self.wrapped_inner, + bound_args=bound_args, expressions=expressions, return_dtype=self.return_dtype, - resource_request=self.common_args.resource_request, - batch_size=self.common_args.batch_size, + init_args=self.init_args, + resource_request=self.resource_request, + batch_size=self.batch_size, + concurrency=self.concurrency, ) - def _bind_func(self, *args, **kwargs) -> inspect.BoundArguments: - sig = inspect.signature(self.func) - bound_args = sig.bind(*args, **kwargs) - bound_args.apply_defaults() - return bound_args - - def __hash__(self) -> int: - return hash((self.func, self.return_dtype)) - def override_options( self, *, @@ -304,7 +256,7 @@ def override_options( num_gpus: float | None = _UnsetMarker, memory_bytes: int | None = _UnsetMarker, batch_size: int | None = _UnsetMarker, - ) -> StatelessUDF: + ) -> UDF: """Replace the resource requests for running each instance of your UDF. For instance, if your UDF requires 4 CPUs to run, you can configure it like so: @@ -312,15 +264,12 @@ def override_options( >>> import daft >>> >>> @daft.udf(return_dtype=daft.DataType.string()) - ... def example_stateless_udf(inputs): + ... def example_udf(inputs): ... # You will have access to 4 CPUs here if you configure your UDF correctly! ... return inputs >>> >>> # Parametrize the UDF to run with 4 CPUs - >>> example_stateless_udf_4CPU = example_stateless_udf.override_options(num_cpus=4) - >>> - >>> df = daft.from_pydict({"foo": [1, 2, 3]}) - >>> df = df.with_column("bar", example_stateless_udf_4CPU(df["foo"])) + >>> example_udf_4CPU = example_udf.override_options(num_cpus=4) Args: num_cpus: Number of CPUs to allocate each running instance of your UDF. Note that this is purely used for placement (e.g. if your @@ -331,152 +280,57 @@ def override_options( this parameter can help hint Daft that each UDF requires a certain amount of heap memory for execution. batch_size: Enables batching of the input into batches of at most this size. Results between batches are concatenated. 
""" - new_common_args = self.common_args.override_options( - num_cpus=num_cpus, num_gpus=num_gpus, memory_bytes=memory_bytes, batch_size=batch_size - ) - return dataclasses.replace(self, common_args=new_common_args) - - -@dataclasses.dataclass -class StatefulUDF: - """A StatefulUDF is produced by calling `@udf` over a Python class, allowing for maintaining state between calls: it can be further parametrized at runtime with custom concurrency, resources, and init args. - - Example of a Stateful UDF: - >>> import daft - >>> - >>> @daft.udf(return_dtype=daft.DataType.string()) - ... class MyStatefulUdf: - ... def __init__(self, prefix: str = "Goodbye"): - ... self.prefix = prefix - ... - ... def __call__(self, name: daft.Series) -> list: - ... return [f"{self.prefix}, {n}!" for n in name.to_pylist()] - >>> - >>> MyHelloStatefulUdf = MyStatefulUdf.with_init_args(prefix="Hello") - >>> - >>> df = daft.from_pydict({"name": ["Alice", "Bob", "Charlie"]}) - >>> df = df.with_column("greeting", MyHelloStatefulUdf(df["name"])) - >>> df.show() - ╭─────────┬─────────────────╮ - │ name ┆ greeting │ - │ --- ┆ --- │ - │ Utf8 ┆ Utf8 │ - ╞═════════╪═════════════════╡ - │ Alice ┆ Hello, Alice! │ - ├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ - │ Bob ┆ Hello, Bob! │ - ├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ - │ Charlie ┆ Hello, Charlie! │ - ╰─────────┴─────────────────╯ - - (Showing first 3 of 3 rows) - - The state (in this case, the prefix) is maintained across calls to the UDF. Most commonly, this state is - used for things such as ML models which should be downloaded and loaded into memory once for multiple - invocations. - """ - - common_args: CommonUDFArgs - name: str - cls: type - return_dtype: DataType - init_args: InitArgsType = None - concurrency: int | None = None - - def __post_init__(self): - """Analogous to the @functools.wraps(self.cls) pattern - - This will swap out identifiers on `self` to match `self.cls`. Most notably, this swaps out - self.__module__ and self.__qualname__, which is used in `__reduce__` during serialization. - """ - functools.update_wrapper(self, self.cls) - - def __call__(self, *args, **kwargs) -> Expression: - """Call the UDF using some input Expressions, producing a new Expression that can be used by a DataFrame. - Args: - *args: Positional arguments to be passed to the UDF. These can be either Expressions or Python values. - **kwargs: Keyword arguments to be passed to the UDF. These can be either Expressions or Python values. - - Returns: - Expression: A new Expression representing the UDF call, which can be used in DataFrame operations. - - .. NOTE:: - When passing arguments to the UDF, you can use a mix of Expressions (e.g., df["column"]) and Python values. - Expressions will be evaluated for each row, while Python values will be passed as-is to the UDF. - - Example: - >>> import daft - >>> - >>> @daft.udf(return_dtype=daft.DataType.float64()) - ... class MultiplyAndAdd: - ... def __init__(self, multiplier: float = 2.0): - ... self.multiplier = multiplier - ... - ... def __call__(self, x: daft.Series, z: float) -> list: - ... 
return [val * self.multiplier + z for val in x.to_pylist()] - >>> - >>> df = daft.from_pydict({"x": [1, 2, 3]}) - >>> df = df.with_column("result", MultiplyAndAdd(df["x"], z=1.5)) - >>> df.show() - ╭───────┬─────────╮ - │ x ┆ result │ - │ --- ┆ --- │ - │ Int64 ┆ Float64 │ - ╞═══════╪═════════╡ - │ 1 ┆ 3.5 │ - ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤ - │ 2 ┆ 5.5 │ - ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤ - │ 3 ┆ 7.5 │ - ╰───────┴─────────╯ - - (Showing first 3 of 3 rows) - """ - # Validate that the UDF has a concurrency set, if running with actor pool projections - if get_context().daft_planning_config.enable_actor_pool_projections: - if self.concurrency is None: + new_resource_request = ResourceRequest() if self.resource_request is None else self.resource_request + if num_cpus is not _UnsetMarker: + new_resource_request = new_resource_request.with_num_cpus(num_cpus) + if num_gpus is not _UnsetMarker: + new_resource_request = new_resource_request.with_num_gpus(num_gpus) + if memory_bytes is not _UnsetMarker: + new_resource_request = new_resource_request.with_memory_bytes(memory_bytes) + + new_batch_size = self.batch_size if batch_size is _UnsetMarker else batch_size + + return dataclasses.replace(self, resource_request=new_resource_request, batch_size=new_batch_size) + + def _validate_init_args(self): + if isinstance(self.inner, type): + init_sig = inspect.signature(self.inner.__init__) # type: ignore + if ( + any(param.default is param.empty for param in init_sig.parameters.values() if param.name != "self") + and self.init_args is None + ): raise ValueError( - "Cannot call StatefulUDF without supplying a concurrency argument. Daft needs to know how many instances of your StatefulUDF to run concurrently. Please parametrize your UDF using `.with_concurrency(N)` before invoking it!" + "Cannot call class UDF without initialization arguments. Please either specify default arguments in your __init__ or provide " + "initialization arguments using `.with_init_args(...)`." ) - elif self.concurrency is not None: - raise ValueError( - "StatefulUDF cannot be run with concurrency specified without the experimental DAFT_ENABLE_ACTOR_POOL_PROJECTIONS=1 flag set." - ) - - # Validate that initialization arguments are provided if the __init__ signature indicates that there are - # parameters without defaults - init_sig = inspect.signature(self.cls.__init__) # type: ignore - if ( - any(param.default is param.empty for param in init_sig.parameters.values() if param.name != "self") - and self.init_args is None - ): - raise ValueError( - "Cannot call StatefulUDF without initialization arguments. Please either specify default arguments in your __init__ or provide " - "initialization arguments using `.with_init_args(...)`." 
+ else: + if self.init_args is not None: + raise ValueError("Function UDFs cannot have init args.") + + def _bind_args(self, *args, **kwargs) -> BoundUDFArgs: + if isinstance(self.inner, type): + sig = inspect.signature(self.inner.__call__) + bound_args = sig.bind( + # Placeholder for `self` + None, + *args, + **kwargs, ) + else: + sig = inspect.signature(self.inner) + bound_args = sig.bind(*args, **kwargs) + bound_args.apply_defaults() + return BoundUDFArgs(bound_args) - bound_args = BoundUDFArgs(self._bind_func(*args, **kwargs)) - expressions = list(bound_args.expressions().values()) - return Expression.stateful_udf( - name=self.name, - partial=PartialStatefulUDF(self.cls, self.return_dtype, bound_args), - expressions=expressions, - return_dtype=self.return_dtype, - resource_request=self.common_args.resource_request, - init_args=self.init_args, - batch_size=self.common_args.batch_size, - concurrency=self.concurrency, - ) - - def with_concurrency(self, concurrency: int) -> StatefulUDF: - """Override the concurrency of this StatefulUDF, which tells Daft how many instances of your StatefulUDF to run concurrently. + def with_concurrency(self, concurrency: int) -> UDF: + """Override the concurrency of this UDF, which tells Daft how many instances of your UDF to run concurrently. Example: >>> import daft >>> >>> @daft.udf(return_dtype=daft.DataType.string(), num_gpus=1) - ... class MyUDFThatNeedsAGPU: + ... class MyGpuUdf: ... def __init__(self, text=" world"): ... self.text = text ... @@ -484,12 +338,12 @@ def with_concurrency(self, concurrency: int) -> StatefulUDF: ... return [x + self.text for x in data.to_pylist()] >>> >>> # New UDF that will have 8 concurrent running instances (will require 8 total GPUs) - >>> MyUDFThatNeedsAGPU_8_concurrency = MyUDFThatNeedsAGPU.with_concurrency(8) + >>> MyGpuUdf_8_concurrency = MyGpuUdf.with_concurrency(8) """ return dataclasses.replace(self, concurrency=concurrency) - def with_init_args(self, *args, **kwargs) -> StatefulUDF: - """Replace initialization arguments for the Stateful UDF when calling `__init__` at runtime + def with_init_args(self, *args, **kwargs) -> UDF: + """Replace initialization arguments for a class UDF when calling `__init__` at runtime on each instance of the UDF. Example: @@ -497,19 +351,19 @@ def with_init_args(self, *args, **kwargs) -> StatefulUDF: >>> import daft >>> >>> @daft.udf(return_dtype=daft.DataType.string()) - ... class MyInitializedClass: + ... class MyUdfWithInit: ... def __init__(self, text=" world"): ... self.text = text ... ... def __call__(self, data): ... 
return [x + self.text for x in data.to_pylist()] >>> - >>> # Create a customized version of MyInitializedClass by overriding the init args - >>> MyInitializedClass_CustomInitArgs = MyInitializedClass.with_init_args(text=" my old friend") + >>> # Create a customized version of MyUdfWithInit by overriding the init args + >>> MyUdfWithInit_CustomInitArgs = MyUdfWithInit.with_init_args(text=" my old friend") >>> >>> df = daft.from_pydict({"foo": ["hello", "hello", "hello"]}) - >>> df = df.with_column("bar_world", MyInitializedClass(df["foo"])) - >>> df = df.with_column("bar_custom", MyInitializedClass_CustomInitArgs(df["foo"])) + >>> df = df.with_column("bar_world", MyUdfWithInit(df["foo"])) + >>> df = df.with_column("bar_custom", MyUdfWithInit_CustomInitArgs(df["foo"])) >>> df.show() ╭───────┬─────────────┬─────────────────────╮ │ foo ┆ bar_world ┆ bar_custom │ @@ -525,7 +379,10 @@ def with_init_args(self, *args, **kwargs) -> StatefulUDF: (Showing first 3 of 3 rows) """ - init_sig = inspect.signature(self.cls.__init__) # type: ignore + if not isinstance(self.inner, type): + raise ValueError("Function UDFs cannot have init args.") + + init_sig = inspect.signature(self.inner.__init__) # type: ignore init_sig.bind( # Placeholder for `self` None, @@ -534,58 +391,8 @@ def with_init_args(self, *args, **kwargs) -> StatefulUDF: ) return dataclasses.replace(self, init_args=(args, kwargs)) - def override_options( - self, - *, - num_cpus: float | None = _UnsetMarker, - num_gpus: float | None = _UnsetMarker, - memory_bytes: int | None = _UnsetMarker, - batch_size: int | None = _UnsetMarker, - ) -> StatefulUDF: - """Replace the resource requests for running each instance of your UDF. - - For instance, if your UDF requires 4 CPUs to run, you can configure it like so: - - >>> import daft - >>> - >>> @daft.udf(return_dtype=daft.DataType.string()) - ... def example_stateless_udf(inputs): - ... # You will have access to 4 CPUs here if you configure your UDF correctly! - ... return inputs - >>> - >>> # Parametrize the UDF to run with 4 CPUs - >>> example_stateless_udf_4CPU = example_stateless_udf.override_options(num_cpus=4) - >>> - >>> df = daft.from_pydict({"foo": [1, 2, 3]}) - >>> df = df.with_column("bar", example_stateless_udf_4CPU(df["foo"])) - - Args: - num_cpus: Number of CPUs to allocate each running instance of your UDF. Note that this is purely used for placement (e.g. if your - machine has 8 CPUs and you specify num_cpus=4, then Daft can run at most 2 instances of your UDF at a time). - num_gpus: Number of GPUs to allocate each running instance of your UDF. This is used for placement and also for allocating - the appropriate GPU to each UDF using `CUDA_VISIBLE_DEVICES`. - memory_bytes: Amount of memory to allocate each running instance of your UDF in bytes. If your UDF is experiencing out-of-memory errors, - this parameter can help hint Daft that each UDF requires a certain amount of heap memory for execution. - batch_size: Enables batching of the input into batches of at most this size. Results between batches are concatenated. 
- """ - new_common_args = self.common_args.override_options( - num_cpus=num_cpus, num_gpus=num_gpus, memory_bytes=memory_bytes, batch_size=batch_size - ) - return dataclasses.replace(self, common_args=new_common_args) - - def _bind_func(self, *args, **kwargs) -> inspect.BoundArguments: - sig = inspect.signature(self.cls.__call__) - bound_args = sig.bind( - # Placeholder for `self` - None, - *args, - **kwargs, - ) - bound_args.apply_defaults() - return bound_args - def __hash__(self) -> int: - return hash((self.cls, self.return_dtype)) + return hash((self.inner, self.return_dtype)) def udf( @@ -595,8 +402,8 @@ def udf( num_gpus: float | None = None, memory_bytes: int | None = None, batch_size: int | None = None, -) -> Callable[[UserProvidedPythonFunction | type], StatelessUDF | StatefulUDF]: - """`@udf` Decorator to convert a Python function/class into a `StatelessUDF` or `StatefulUDF` respectively +) -> Callable[[UserDefinedPyFuncLike], UDF]: + """`@udf` Decorator to convert a Python function/class into a `UDF` UDFs allow users to run arbitrary Python code on the outputs of Expressions. @@ -705,15 +512,17 @@ def udf( batch_size: Enables batching of the input into batches of at most this size. Results between batches are concatenated. Returns: - Callable[[UserProvidedPythonFunction], UDF]: UDF decorator - converts a user-provided Python function as a UDF that can be called on Expressions + Callable[[UserDefinedPyFuncLike], UDF]: UDF decorator - converts a user-provided Python function as a UDF that can be called on Expressions """ - def _udf(f: UserProvidedPythonFunction | type) -> StatelessUDF | StatefulUDF: + def _udf(f: UserDefinedPyFuncLike) -> UDF: # Grab a name for the UDF. It **should** be unique. - name = getattr(f, "__module__", "") # type: ignore[call-overload] - if name: - name = name + "." - name = name + getattr(f, "__qualname__") # type: ignore[call-overload] + module_name = getattr(f, "__module__", "") # type: ignore[call-overload] + qual_name = getattr(f, "__qualname__") # type: ignore[call-overload] + if module_name: + name = f"{module_name}.{qual_name}" + else: + name = qual_name resource_request = ( None @@ -725,25 +534,12 @@ def _udf(f: UserProvidedPythonFunction | type) -> StatelessUDF | StatefulUDF: ) ) - if inspect.isclass(f): - return StatefulUDF( - name=name, - cls=f, - return_dtype=return_dtype, - common_args=CommonUDFArgs( - resource_request=resource_request, - batch_size=batch_size, - ), - ) - else: - return StatelessUDF( - name=name, - func=f, - return_dtype=return_dtype, - common_args=CommonUDFArgs( - resource_request=resource_request, - batch_size=batch_size, - ), - ) + return UDF( + inner=f, + name=name, + return_dtype=return_dtype, + resource_request=resource_request, + batch_size=batch_size, + ) return _udf diff --git a/docs/source/api_docs/udf.rst b/docs/source/api_docs/udf.rst index 1e095eaadc..3f7354be90 100644 --- a/docs/source/api_docs/udf.rst +++ b/docs/source/api_docs/udf.rst @@ -7,8 +7,7 @@ A UDF can be used just like :doc:`Expressions <../user_guide/expressions>`, allo should be executed by Daft lazily. To write a UDF, you should use the :func:`@udf ` decorator, which can decorate either a Python -function or a Python class, producing either a :class:`StatelessUDF ` or -:class:`StatefulUDF ` respectively. +function or a Python class, producing a :class:`UDF `. For more details, please consult the :doc:`UDF User Guide <../user_guide/udf>` @@ -23,10 +22,6 @@ Creating UDFs Using UDFs ========== -.. 
autoclass:: daft.udf.StatelessUDF - :members: - :special-members: __call__ - -.. autoclass:: daft.udf.StatefulUDF +.. autoclass:: daft.udf.UDF :members: :special-members: __call__ diff --git a/docs/source/user_guide/udf.rst b/docs/source/user_guide/udf.rst index 7507351347..4b387239a7 100644 --- a/docs/source/user_guide/udf.rst +++ b/docs/source/user_guide/udf.rst @@ -143,7 +143,7 @@ Your UDF function itself needs to return a batch of columnar data, and can do so Note that if the data you have returned is not castable to the return_dtype that you specify (e.g. if you return a list of floats when you've specified a ``return_dtype=DataType.bool()``), Daft will throw a runtime error! -Stateful UDFs +Class UDFs ------------- UDFs can also be created on Classes, which allow for initialization on some expensive state that can be shared @@ -161,7 +161,7 @@ between invocations of the class, for example downloading data or creating a mod def __call__(self, features_col): return self._model(features_col) -Running Stateful UDFs are exactly the same as running their Stateless cousins. +Running Class UDFs are exactly the same as running their functional cousins. .. code:: python diff --git a/src/common/daft-config/src/lib.rs b/src/common/daft-config/src/lib.rs index ac8600936d..4aeae897a7 100644 --- a/src/common/daft-config/src/lib.rs +++ b/src/common/daft-config/src/lib.rs @@ -10,21 +10,12 @@ use serde::{Deserialize, Serialize}; #[derive(Clone, Serialize, Deserialize, Default, Debug)] pub struct DaftPlanningConfig { pub default_io_config: IOConfig, - pub enable_actor_pool_projections: bool, } impl DaftPlanningConfig { #[must_use] pub fn from_env() -> Self { - let mut cfg = Self::default(); - - let enable_actor_pool_projections_env_var_name = "DAFT_ENABLE_ACTOR_POOL_PROJECTIONS"; - if let Ok(val) = std::env::var(enable_actor_pool_projections_env_var_name) - && matches!(val.trim().to_lowercase().as_str(), "1" | "true") - { - cfg.enable_actor_pool_projections = true; - } - cfg + Default::default() } } diff --git a/src/common/daft-config/src/python.rs b/src/common/daft-config/src/python.rs index ad86de49ae..aceefd63d2 100644 --- a/src/common/daft-config/src/python.rs +++ b/src/common/daft-config/src/python.rs @@ -29,21 +29,13 @@ impl PyDaftPlanningConfig { } } - fn with_config_values( - &mut self, - default_io_config: Option, - enable_actor_pool_projections: Option, - ) -> PyResult { + fn with_config_values(&mut self, default_io_config: Option) -> PyResult { let mut config = self.config.as_ref().clone(); if let Some(default_io_config) = default_io_config { config.default_io_config = default_io_config.config; } - if let Some(enable_actor_pool_projections) = enable_actor_pool_projections { - config.enable_actor_pool_projections = enable_actor_pool_projections; - } - Ok(Self { config: Arc::new(config), }) @@ -55,11 +47,6 @@ impl PyDaftPlanningConfig { config: self.config.default_io_config.clone(), }) } - - #[getter(enable_actor_pool_projections)] - fn enable_actor_pool_projections(&self) -> PyResult { - Ok(self.config.enable_actor_pool_projections) - } } impl_bincode_py_state_serialization!(PyDaftPlanningConfig); diff --git a/src/daft-dsl/src/expr/mod.rs b/src/daft-dsl/src/expr/mod.rs index 84a6791371..de41b34cfa 100644 --- a/src/daft-dsl/src/expr/mod.rs +++ b/src/daft-dsl/src/expr/mod.rs @@ -1273,14 +1273,35 @@ pub fn has_agg(expr: &ExprRef) -> bool { expr.exists(|e| matches!(e.as_ref(), Expr::Agg(_))) } -pub fn has_stateful_udf(expr: &ExprRef) -> bool { - expr.exists(|e| { - matches!( - e.as_ref(), - 
-            Expr::Function {
-                func: FunctionExpr::Python(PythonUDF::Stateful(_)),
+#[inline]
+pub fn is_actor_pool_udf(expr: &ExprRef) -> bool {
+    matches!(
+        expr.as_ref(),
+        Expr::Function {
+            func: FunctionExpr::Python(PythonUDF {
+                concurrency: Some(_),
                 ..
-            }
-        )
-    })
+            }),
+            ..
+        }
+    )
+}
+
+pub fn count_actor_pool_udfs(exprs: &[ExprRef]) -> usize {
+    exprs
+        .iter()
+        .map(|expr| {
+            let mut count = 0;
+            expr.apply(|e| {
+                if is_actor_pool_udf(e) {
+                    count += 1;
+                }
+
+                Ok(common_treenode::TreeNodeRecursion::Continue)
+            })
+            .unwrap();
+
+            count
+        })
+        .sum()
+}
diff --git a/src/daft-dsl/src/functions/mod.rs b/src/daft-dsl/src/functions/mod.rs
index 1fd4401228..7fa3bd8952 100644
--- a/src/daft-dsl/src/functions/mod.rs
+++ b/src/daft-dsl/src/functions/mod.rs
@@ -47,7 +47,7 @@ impl FunctionExpr {
             Self::Map(expr) => expr.get_evaluator(),
             Self::Sketch(expr) => expr.get_evaluator(),
             Self::Struct(expr) => expr.get_evaluator(),
-            Self::Python(expr) => expr.get_evaluator(),
+            Self::Python(expr) => expr,
             Self::Partitioning(expr) => expr.get_evaluator(),
         }
     }
diff --git a/src/daft-dsl/src/functions/python/mod.rs b/src/daft-dsl/src/functions/python/mod.rs
index 88af7d8150..917345d632 100644
--- a/src/daft-dsl/src/functions/python/mod.rs
+++ b/src/daft-dsl/src/functions/python/mod.rs
@@ -1,110 +1,63 @@
 mod runtime_py_object;
 mod udf;
-mod udf_runtime_binding;
 
-#[cfg(feature = "python")]
-use std::collections::HashMap;
 use std::sync::Arc;
 
-#[cfg(feature = "python")]
-use common_error::DaftError;
 use common_error::DaftResult;
 use common_resource_request::ResourceRequest;
 use common_treenode::{TreeNode, TreeNodeRecursion};
-use daft_core::datatypes::DataType;
+use daft_core::prelude::*;
 use itertools::Itertools;
-#[cfg(feature = "python")]
-use pyo3::{Py, PyAny};
 pub use runtime_py_object::RuntimePyObject;
 use serde::{Deserialize, Serialize};
-pub use udf_runtime_binding::UDFRuntimeBinding;
 
-use super::{FunctionEvaluator, FunctionExpr};
+use super::FunctionExpr;
 use crate::{Expr, ExprRef};
 
 #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
-pub enum PythonUDF {
-    Stateless(StatelessPythonUDF),
-    Stateful(StatefulPythonUDF),
-}
-
-impl PythonUDF {
-    #[inline]
-    pub fn get_evaluator(&self) -> &dyn FunctionEvaluator {
-        match self {
-            Self::Stateless(stateless_python_udf) => stateless_python_udf,
-            Self::Stateful(stateful_python_udf) => stateful_python_udf,
-        }
-    }
+pub enum MaybeInitializedUDF {
+    Initialized(RuntimePyObject),
+    Uninitialized {
+        inner: RuntimePyObject,
+        init_args: RuntimePyObject,
+    },
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
-pub struct StatelessPythonUDF {
+pub struct PythonUDF {
     pub name: Arc<str>,
-    partial_func: RuntimePyObject,
-    num_expressions: usize,
-    pub return_dtype: DataType,
-    pub resource_request: Option<ResourceRequest>,
-    pub batch_size: Option<usize>,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
-pub struct StatefulPythonUDF {
-    pub name: Arc<str>,
-    pub stateful_partial_func: RuntimePyObject,
+    pub func: MaybeInitializedUDF,
+    pub bound_args: RuntimePyObject,
     pub num_expressions: usize,
     pub return_dtype: DataType,
     pub resource_request: Option<ResourceRequest>,
-    pub init_args: Option<RuntimePyObject>,
     pub batch_size: Option<usize>,
     pub concurrency: Option<usize>,
-    pub runtime_binding: UDFRuntimeBinding,
-}
-
-pub fn stateless_udf(
-    name: &str,
-    py_partial_stateless_udf: RuntimePyObject,
-    expressions: &[ExprRef],
-    return_dtype: DataType,
-    resource_request: Option<ResourceRequest>,
-    batch_size: Option<usize>,
-) -> DaftResult<ExprRef> {
-    Ok(Expr::Function {
-        func: super::FunctionExpr::Python(PythonUDF::Stateless(StatelessPythonUDF {
-            name: name.to_string().into(),
-            partial_func: py_partial_stateless_udf,
-            num_expressions: expressions.len(),
-            return_dtype,
-            resource_request,
-            batch_size,
-        })),
-        inputs: expressions.into(),
-    })
-}
 
 #[allow(clippy::too_many_arguments)]
-pub fn stateful_udf(
+pub fn udf(
     name: &str,
-    py_stateful_partial_func: RuntimePyObject,
+    inner: RuntimePyObject,
+    bound_args: RuntimePyObject,
     expressions: &[ExprRef],
     return_dtype: DataType,
+    init_args: RuntimePyObject,
     resource_request: Option<ResourceRequest>,
-    init_args: Option<RuntimePyObject>,
     batch_size: Option<usize>,
     concurrency: Option<usize>,
 ) -> DaftResult<ExprRef> {
     Ok(Expr::Function {
-        func: super::FunctionExpr::Python(PythonUDF::Stateful(StatefulPythonUDF {
+        func: super::FunctionExpr::Python(PythonUDF {
             name: name.to_string().into(),
-            stateful_partial_func: py_stateful_partial_func,
+            func: MaybeInitializedUDF::Uninitialized { inner, init_args },
+            bound_args,
             num_expressions: expressions.len(),
             return_dtype,
             resource_request,
-            init_args,
             batch_size,
             concurrency,
-            runtime_binding: UDFRuntimeBinding::Unbound,
-        })),
+        }),
         inputs: expressions.into(),
     })
 }
@@ -119,18 +72,9 @@ pub fn get_resource_request(exprs: &[ExprRef]) -> Option<ResourceRequest> {
         expr.apply(|e| match e.as_ref() {
             Expr::Function {
                 func:
-                    FunctionExpr::Python(PythonUDF::Stateless(StatelessPythonUDF {
-                        resource_request,
-                        ..
-                    })),
-                ..
-            }
-            | Expr::Function {
-                func:
-                    FunctionExpr::Python(PythonUDF::Stateful(StatefulPythonUDF {
-                        resource_request,
-                        ..
-                    })),
+                    FunctionExpr::Python(PythonUDF {
+                        resource_request, ..
+                    }),
                 ..
             } => {
                 if let Some(rr) = resource_request {
@@ -157,101 +101,119 @@ pub fn get_resource_request(exprs: &[ExprRef]) -> Option<ResourceRequest> {
     }
 }
 
-/// Gets the concurrency from the first StatefulUDF encountered in a given slice of expressions
+/// Gets the concurrency from the first UDF encountered in a given slice of expressions
 ///
-/// NOTE: This function panics if no StatefulUDF is found
+/// NOTE: This function panics if no UDF is found or if the first UDF has no concurrency
 pub fn get_concurrency(exprs: &[ExprRef]) -> usize {
     let mut projection_concurrency = None;
     for expr in exprs {
-        let mut found_stateful_udf = false;
+        let mut found_actor_pool_udf = false;
         expr.apply(|e| match e.as_ref() {
             Expr::Function {
-                func: FunctionExpr::Python(PythonUDF::Stateful(StatefulPythonUDF{concurrency, ..})),
+                func: FunctionExpr::Python(PythonUDF { concurrency, .. }),
                 ..
             } => {
-                found_stateful_udf = true;
-                projection_concurrency = Some(concurrency.expect("Should have concurrency specified"));
+                found_actor_pool_udf = true;
+                projection_concurrency =
+                    Some(concurrency.expect("Should have concurrency specified"));
                 Ok(common_treenode::TreeNodeRecursion::Stop)
             }
             _ => Ok(common_treenode::TreeNodeRecursion::Continue),
-        }).unwrap();
-        if found_stateful_udf {
+        })
+        .unwrap();
+        if found_actor_pool_udf {
             break;
         }
     }
-    projection_concurrency.expect("get_concurrency expects one StatefulUDF")
+    projection_concurrency.expect("get_concurrency expects one UDF with concurrency set")
+}
+
+/// Gets the batch size from the first UDF encountered in a given slice of expressions
+pub fn get_batch_size(exprs: &[ExprRef]) -> Option<usize> {
+    let mut projection_batch_size = None;
+    for expr in exprs {
+        let mut found_udf = false;
+        expr.apply(|e| match e.as_ref() {
+            Expr::Function {
+                func: FunctionExpr::Python(PythonUDF { batch_size, .. }),
+                ..
+            } => {
+                found_udf = true;
+                projection_batch_size = Some(*batch_size);
+                Ok(common_treenode::TreeNodeRecursion::Stop)
+            }
+            _ => Ok(common_treenode::TreeNodeRecursion::Continue),
+        })
+        .unwrap();
+        if found_udf {
+            break;
+        }
+    }
+    projection_batch_size.expect("get_batch_size expects one UDF")
 }

-/// Binds every StatefulPythonUDF expression to an initialized function provided by an actor
 #[cfg(feature = "python")]
-pub fn bind_stateful_udfs(
-    expr: ExprRef,
-    initialized_funcs: &HashMap<String, Py<PyAny>>,
-) -> DaftResult<ExprRef> {
+fn py_udf_initialize(
+    func: pyo3::PyObject,
+    init_args: pyo3::PyObject,
+) -> DaftResult<pyo3::PyObject> {
+    use pyo3::Python;
+
+    Ok(Python::with_gil(move |py| {
+        func.call_method1(py, pyo3::intern!(py, "initialize"), (init_args,))
+    })?)
+}
+
+/// Initializes all uninitialized UDFs in the expression
+#[cfg(feature = "python")]
+pub fn initialize_udfs(expr: ExprRef) -> DaftResult<ExprRef> {
+    use common_treenode::Transformed;
+
     expr.transform(|e| match e.as_ref() {
         Expr::Function {
-            func: FunctionExpr::Python(PythonUDF::Stateful(stateful_py_udf)),
+            func:
+                FunctionExpr::Python(
+                    python_udf @ PythonUDF {
+                        func: MaybeInitializedUDF::Uninitialized { inner, init_args },
+                        ..
+                    },
+                ),
             inputs,
         } => {
-            let f = initialized_funcs
-                .get(stateful_py_udf.name.as_ref())
-                .ok_or_else(|| {
-                    DaftError::InternalError(format!(
-                        "Unable to find UDF to bind: {}",
-                        stateful_py_udf.name.as_ref()
-                    ))
-                })?;
-            let bound_expr = Expr::Function {
-                func: FunctionExpr::Python(PythonUDF::Stateful(StatefulPythonUDF {
-                    runtime_binding: udf_runtime_binding::UDFRuntimeBinding::Bound(
-                        f.clone().into(),
-                    ),
-                    ..stateful_py_udf.clone()
-                })),
+            let initialized_func =
+                py_udf_initialize(inner.clone().unwrap(), init_args.clone().unwrap())?;
+
+            let initialized_expr = Expr::Function {
+                func: FunctionExpr::Python(PythonUDF {
+                    func: MaybeInitializedUDF::Initialized(initialized_func.into()),
+                    ..python_udf.clone()
+                }),
                 inputs: inputs.clone(),
             };
-            Ok(common_treenode::Transformed::yes(bound_expr.into()))
+
+            Ok(Transformed::yes(initialized_expr.into()))
         }
-        _ => Ok(common_treenode::Transformed::no(e)),
+        _ => Ok(Transformed::no(e)),
     })
     .map(|transformed| transformed.data)
 }

-/// Helper function that extracts all PartialStatefulUDF python objects from a given expression tree
-#[cfg(feature = "python")]
-pub fn extract_partial_stateful_udf_py(
-    expr: ExprRef,
-) -> HashMap<String, (Py<PyAny>, Option<Py<PyAny>>)> {
-    extract_stateful_udf_exprs(expr)
-        .into_iter()
-        .map(|stateful_udf| {
-            (
-                stateful_udf.name.as_ref().to_string(),
-                (
-                    stateful_udf.stateful_partial_func.as_ref().clone(),
-                    stateful_udf.init_args.map(|x| x.as_ref().clone()),
-                ),
-            )
-        })
-        .collect()
-}
-
-/// Helper function that extracts all `StatefulPythonUDF` expressions from a given expression tree
-pub fn extract_stateful_udf_exprs(expr: ExprRef) -> Vec<StatefulPythonUDF> {
-    let mut stateful_udf_exprs = Vec::new();
+/// Get the names of all UDFs in the expression
+pub fn get_udf_names(expr: &ExprRef) -> Vec<String> {
+    let mut names = Vec::new();

-    expr.apply(|child| {
+    expr.apply(|e| {
         if let Expr::Function {
-            func: FunctionExpr::Python(PythonUDF::Stateful(stateful_udf)),
+            func: FunctionExpr::Python(PythonUDF { name, .. }),
             ..
-        } = child.as_ref()
+        } = e.as_ref()
         {
-            stateful_udf_exprs.push(stateful_udf.clone());
+            names.push(name.to_string());
         }
         Ok(TreeNodeRecursion::Continue)
     })
     .unwrap();

-    stateful_udf_exprs
+    names
 }
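
Aside: the traversal helpers above (`is_actor_pool_udf`, `count_actor_pool_udfs`, `get_concurrency`, `get_batch_size`, `get_udf_names`) replace the old `PythonUDF::Stateful` pattern matches with plain tree walks. A minimal sketch of how a caller could combine them, using only the `daft_dsl` APIs introduced in this diff; `inspect_projection` itself is a hypothetical helper, not part of the PR:

use daft_dsl::{count_actor_pool_udfs, functions::python::get_udf_names, is_actor_pool_udf, ExprRef};

// Hypothetical: summarize the UDFs in a projection before planning it.
fn inspect_projection(projection: &[ExprRef]) {
    // Counts every UDF expression with `concurrency: Some(_)` across all trees.
    let n = count_actor_pool_udfs(projection);
    println!("actor pool UDFs in projection: {n}");
    for expr in projection {
        // `exists` short-circuits on the first matching node, unlike the count above.
        if expr.exists(is_actor_pool_udf) {
            println!("{expr} contains UDFs: {:?}", get_udf_names(expr));
        }
    }
}
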
diff --git a/src/daft-dsl/src/functions/python/udf.rs b/src/daft-dsl/src/functions/python/udf.rs
index 100bd06566..e08257cdd9 100644
--- a/src/daft-dsl/src/functions/python/udf.rs
+++ b/src/daft-dsl/src/functions/python/udf.rs
@@ -6,47 +6,9 @@ use pyo3::{
     Bound, PyAny, PyResult,
 };

-use super::{super::FunctionEvaluator, StatefulPythonUDF, StatelessPythonUDF};
+use super::{super::FunctionEvaluator, PythonUDF};
 use crate::{functions::FunctionExpr, ExprRef};

-impl FunctionEvaluator for StatelessPythonUDF {
-    fn fn_name(&self) -> &'static str {
-        "py_udf"
-    }
-
-    fn to_field(
-        &self,
-        inputs: &[ExprRef],
-        _schema: &Schema,
-        _: &FunctionExpr,
-    ) -> DaftResult<Field> {
-        if inputs.len() != self.num_expressions {
-            return Err(DaftError::SchemaMismatch(format!(
-                "Number of inputs required by UDF {} does not match number of inputs provided: {}",
-                self.num_expressions,
-                inputs.len()
-            )));
-        }
-        match inputs {
-            [] => Err(DaftError::ValueError(
-                "Cannot run UDF with 0 expression arguments".into(),
-            )),
-            [first, ..] => Ok(Field::new(first.name(), self.return_dtype.clone())),
-        }
-    }
-
-    fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult<Series> {
-        #[cfg(not(feature = "python"))]
-        {
-            panic!("Cannot evaluate a StatelessPythonUDF without compiling for Python");
-        }
-        #[cfg(feature = "python")]
-        {
-            self.call_udf(inputs)
-        }
-    }
-}
-
 #[cfg(feature = "python")]
 fn run_udf(
     py: pyo3::Python,
@@ -96,11 +58,13 @@ fn run_udf(
     }
 }

-impl StatelessPythonUDF {
+impl PythonUDF {
     #[cfg(feature = "python")]
     pub fn call_udf(&self, inputs: &[Series]) -> DaftResult<Series> {
         use pyo3::Python;

+        use crate::functions::python::{py_udf_initialize, MaybeInitializedUDF};
+
         if inputs.len() != self.num_expressions {
             return Err(DaftError::SchemaMismatch(format!(
                 "Number of inputs required by UDF {} does not match number of inputs provided: {}",
                 self.num_expressions,
                 inputs.len()
             )));
         }

-        Python::with_gil(|py| {
-            // Extract the required Python objects to call our run_udf helper
-            let func = self
-                .partial_func
-                .as_ref()
-                .getattr(py, pyo3::intern!(py, "func"))?;
-            let bound_args = self
-                .partial_func
-                .as_ref()
-                .getattr(py, pyo3::intern!(py, "bound_args"))?;
+        let func = match &self.func {
+            MaybeInitializedUDF::Initialized(func) => func.clone().unwrap(),
+            MaybeInitializedUDF::Uninitialized { inner, init_args } => {
+                // TODO(Kevin): warn user if initialization is taking too long and ask them to use actor pool UDFs
+
+                py_udf_initialize(inner.clone().unwrap(), init_args.clone().unwrap())?
+            }
+        };

+        Python::with_gil(|py| {
             run_udf(
                 py,
                 inputs,
                 func,
-                bound_args,
+                self.bound_args.clone().unwrap(),
                 &self.return_dtype,
                 self.batch_size,
             )
         })
     }
 }

-impl FunctionEvaluator for StatefulPythonUDF {
+impl FunctionEvaluator for PythonUDF {
     fn fn_name(&self) -> &'static str {
-        "pyclass_udf"
+        "py_udf"
     }

-    fn to_field(
-        &self,
-        inputs: &[ExprRef],
-        _schema: &Schema,
-        _: &FunctionExpr,
-    ) -> DaftResult<Field> {
+    fn to_field(&self, inputs: &[ExprRef], _: &Schema, _: &FunctionExpr) -> DaftResult<Field> {
         if inputs.len() != self.num_expressions {
             return Err(DaftError::SchemaMismatch(format!(
                 "Number of inputs required by UDF {} does not match number of inputs provided: {}",
@@ -161,90 +119,11 @@ impl FunctionEvaluator for StatefulPythonUDF {
     fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult<Series> {
         #[cfg(not(feature = "python"))]
         {
-            panic!("Cannot evaluate a StatelessPythonUDF without compiling for Python");
+            panic!("Cannot evaluate a PythonUDF without compiling for Python");
         }
-
         #[cfg(feature = "python")]
         {
-            use pyo3::{
-                types::{PyDict, PyTuple},
-                Python,
-            };
-
-            use crate::functions::python::udf_runtime_binding::UDFRuntimeBinding;
-
-            if inputs.len() != self.num_expressions {
-                return Err(DaftError::SchemaMismatch(format!(
-                    "Number of inputs required by UDF {} does not match number of inputs provided: {}",
-                    self.num_expressions,
-                    inputs.len()
-                )));
-            }
-
-            if let UDFRuntimeBinding::Bound(func) = &self.runtime_binding {
-                Python::with_gil(|py| {
-                    // Extract the required Python objects to call our run_udf helper
-                    let bound_args = self
-                        .stateful_partial_func
-                        .as_ref()
-                        .getattr(py, pyo3::intern!(py, "bound_args"))?;
-                    run_udf(
-                        py,
-                        inputs,
-                        pyo3::Py::clone_ref(func.as_ref(), py),
-                        bound_args,
-                        &self.return_dtype,
-                        self.batch_size,
-                    )
-                })
-            } else {
-                // NOTE: This branch of evaluation performs a naive initialization of the class. It is performed once-per-evaluate
-                // which is not ideal. Callers trying to .evaluate a StatefulPythonUDF should first bind it to initialized classes.
-                Python::with_gil(|py| {
-                    // Extract the required Python objects to call our run_udf helper
-                    let func = self
-                        .stateful_partial_func
-                        .as_ref()
-                        .getattr(py, pyo3::intern!(py, "func_cls"))?;
-                    let bound_args = self
-                        .stateful_partial_func
-                        .as_ref()
-                        .getattr(py, pyo3::intern!(py, "bound_args"))?;
-
-                    let func = match &self.init_args {
-                        None => func.call0(py)?,
-                        Some(init_args) => {
-                            let init_args = init_args
-                                .as_ref()
-                                .bind(py)
-                                .downcast::<PyTuple>()
-                                .expect("init_args should be a Python tuple");
-                            let (args, kwargs) = (
-                                init_args
-                                    .get_item(0)?
-                                    .downcast::<PyTuple>()
-                                    .expect("init_args[0] should be a tuple of *args")
-                                    .clone(),
-                                init_args
-                                    .get_item(1)?
-                                    .downcast::<PyDict>()
-                                    .expect("init_args[1] should be a dict of **kwargs")
-                                    .clone(),
-                            );
-                            func.call_bound(py, args, Some(&kwargs))?
-                        }
-                    };
-
-                    run_udf(
-                        py,
-                        inputs,
-                        func,
-                        bound_args,
-                        &self.return_dtype,
-                        self.batch_size,
-                    )
-                })
-            }
+            self.call_udf(inputs)
         }
     }
 }
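
With `call_udf` resolving the callable itself, the `udf_runtime_binding` side channel deleted below becomes unnecessary: initialization state now lives in the expression as `MaybeInitializedUDF`, which hashes and serializes like any other expression node. A reduced sketch of the resolution step used above (`resolve_callable` is a hypothetical name for illustration; visibility and error handling as in the diff):

// Sketch: pick or produce the initialized Python callable for a PythonUDF.
// Initialized   => reuse the callable (e.g. produced by initialize_udfs in an actor);
// Uninitialized => initialize inline, paying the cost on every call_udf.
fn resolve_callable(udf: &PythonUDF) -> DaftResult<pyo3::PyObject> {
    match &udf.func {
        MaybeInitializedUDF::Initialized(func) => Ok(func.clone().unwrap()),
        MaybeInitializedUDF::Uninitialized { inner, init_args } => {
            py_udf_initialize(inner.clone().unwrap(), init_args.clone().unwrap())
        }
    }
}
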
diff --git a/src/daft-dsl/src/functions/python/udf_runtime_binding.rs b/src/daft-dsl/src/functions/python/udf_runtime_binding.rs
deleted file mode 100644
index 6425f61784..0000000000
--- a/src/daft-dsl/src/functions/python/udf_runtime_binding.rs
+++ /dev/null
@@ -1,72 +0,0 @@
-use std::hash::{Hash, Hasher};
-
-use serde::{de::Visitor, Deserialize, Serialize};
-
-use super::RuntimePyObject;
-
-/// A binding between the StatefulPythonUDF and an initialized Python callable
-///
-/// This is `Unbound` during planning, and bound to an initialized Python callable
-/// by an Actor right before execution.
-///
-/// Note that attempting to Hash, Eq, Serde this when it is bound will panic!
-#[derive(Debug, Clone)]
-pub enum UDFRuntimeBinding {
-    Unbound,
-    Bound(RuntimePyObject),
-}
-
-impl PartialEq for UDFRuntimeBinding {
-    fn eq(&self, other: &Self) -> bool {
-        matches!((self, other), (Self::Unbound, Self::Unbound))
-    }
-}
-
-impl Eq for UDFRuntimeBinding {}
-
-impl Hash for UDFRuntimeBinding {
-    fn hash<H: Hasher>(&self, state: &mut H) {
-        match self {
-            Self::Unbound => state.write_u8(0),
-            Self::Bound(_) => panic!("Cannot hash a bound UDFRuntimeBinding."),
-        }
-    }
-}
-
-impl Serialize for UDFRuntimeBinding {
-    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
-    where
-        S: serde::Serializer,
-    {
-        match self {
-            Self::Unbound => serializer.serialize_unit(),
-            Self::Bound(_) => panic!("Cannot serialize a bound UDFRuntimeBinding."),
-        }
-    }
-}
-
-struct UDFRuntimeBindingVisitor;
-
-impl<'de> Visitor<'de> for UDFRuntimeBindingVisitor {
-    type Value = UDFRuntimeBinding;
-
-    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
-        formatter.write_str("unit")
-    }
-
-    fn visit_unit<E>(self) -> Result<Self::Value, E>
-    where
-        E: serde::de::Error,
-    {
-        Ok(UDFRuntimeBinding::Unbound)
-    }
-}
-
-impl<'de> Deserialize<'de> for UDFRuntimeBinding {
-    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
-    where
-        D: serde::Deserializer<'de>,
-    {
-        deserializer.deserialize_unit(UDFRuntimeBindingVisitor)
-    }
-}
diff --git a/src/daft-dsl/src/lib.rs b/src/daft-dsl/src/lib.rs
index 6c84bbe0e2..4c92027e9a 100644
--- a/src/daft-dsl/src/lib.rs
+++ b/src/daft-dsl/src/lib.rs
@@ -15,9 +15,9 @@ mod resolve_expr;
 mod treenode;
 pub use common_treenode;
 pub use expr::{
-    binary_op, col, has_agg, has_stateful_udf, is_partition_compatible, AggExpr,
-    ApproxPercentileParams, Expr, ExprRef, Operator, OuterReferenceColumn, SketchType, Subquery,
-    SubqueryPlan,
+    binary_op, col, count_actor_pool_udfs, has_agg, is_actor_pool_udf, is_partition_compatible,
+    AggExpr, ApproxPercentileParams, Expr, ExprRef, Operator, OuterReferenceColumn, SketchType,
+    Subquery, SubqueryPlan,
 };
 pub use lit::{lit, literal_value, literals_to_series, null_lit, Literal, LiteralValue};
 #[cfg(feature = "python")]
@@ -37,13 +37,9 @@ pub fn register_modules(parent: &Bound<PyModule>) -> PyResult<()> {
     parent.add_function(wrap_pyfunction_bound!(python::interval_lit, parent)?)?;
     parent.add_function(wrap_pyfunction_bound!(python::decimal_lit, parent)?)?;
     parent.add_function(wrap_pyfunction_bound!(python::series_lit, parent)?)?;
-    parent.add_function(wrap_pyfunction_bound!(python::stateless_udf, parent)?)?;
-    parent.add_function(wrap_pyfunction_bound!(python::stateful_udf, parent)?)?;
-    parent.add_function(wrap_pyfunction_bound!(
-        python::extract_partial_stateful_udf_py,
-        parent
-    )?)?;
-    parent.add_function(wrap_pyfunction_bound!(python::bind_stateful_udfs, parent)?)?;
+    parent.add_function(wrap_pyfunction_bound!(python::udf, parent)?)?;
+    parent.add_function(wrap_pyfunction_bound!(python::initialize_udfs, parent)?)?;
+    parent.add_function(wrap_pyfunction_bound!(python::get_udf_names, parent)?)?;
     parent.add_function(wrap_pyfunction_bound!(python::eq, parent)?)?;
     parent.add_function(wrap_pyfunction_bound!(
         python::check_column_name_validity,
diff --git a/src/daft-dsl/src/pyobj_serde.rs b/src/daft-dsl/src/pyobj_serde.rs
index 8abdfc76bd..259f9dc7d8 100644
--- a/src/daft-dsl/src/pyobj_serde.rs
+++ b/src/daft-dsl/src/pyobj_serde.rs
@@ -7,7 +7,7 @@ use common_py_serde::{deserialize_py_object, serialize_py_object};
 use pyo3::{types::PyAnyMethods, PyObject, Python};
 use serde::{Deserialize, Serialize};

-// This is a Rust wrapper on top of a Python PartialStatelessUDF or PartialStatefulUDF to make it serde-able and hashable
+// This is a Rust wrapper on top of a Python UDF to make it serde-able and hashable
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct PyObjectWrapper(
     #[serde(
diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs
index f350ef5ff8..928e3ced13 100644
--- a/src/daft-dsl/src/python.rs
+++ b/src/daft-dsl/src/python.rs
@@ -1,6 +1,6 @@
 #![allow(non_snake_case)]
 use std::{
-    collections::{hash_map::DefaultHasher, HashMap},
+    collections::hash_map::DefaultHasher,
     hash::{Hash, Hasher},
     sync::Arc,
 };
@@ -173,60 +173,20 @@ pub fn lit(item: Bound<PyAny>) -> PyResult<PyExpr> {
     }
 }

-// Create a UDF Expression using:
-// * `func` - a Python function that takes as input an ordered list of Python Series to execute the user's UDF.
-// * `expressions` - an ordered list of Expressions, each representing computation that will be performed, producing a Series to pass into `func`
-// * `return_dtype` - returned column's DataType
-#[pyfunction]
-pub fn stateless_udf(
-    name: &str,
-    partial_stateless_udf: PyObject,
-    expressions: Vec<PyExpr>,
-    return_dtype: PyDataType,
-    resource_request: Option<ResourceRequest>,
-    batch_size: Option<usize>,
-) -> PyResult<PyExpr> {
-    use crate::functions::python::stateless_udf;
-
-    if let Some(batch_size) = batch_size {
-        if batch_size == 0 {
-            return Err(PyValueError::new_err(format!(
-                "Error creating UDF: batch size must be positive (got {batch_size})"
-            )));
-        }
-    }
-
-    let expressions_map: Vec<ExprRef> = expressions.into_iter().map(|pyexpr| pyexpr.expr).collect();
-    Ok(PyExpr {
-        expr: stateless_udf(
-            name,
-            partial_stateless_udf.into(),
-            &expressions_map,
-            return_dtype.dtype,
-            resource_request,
-            batch_size,
-        )?
-        .into(),
-    })
-}
-
-// Create a UDF Expression using:
-// * `cls` - a Python class that has an __init__, and where __call__ takes as input an ordered list of Python Series to execute the user's UDF.
-// * `expressions` - an ordered list of Expressions, each representing computation that will be performed, producing a Series to pass into `func`
-// * `return_dtype` - returned column's DataType
 #[pyfunction]
 #[allow(clippy::too_many_arguments)]
-pub fn stateful_udf(
+pub fn udf(
     name: &str,
-    partial_stateful_udf: PyObject,
+    inner: PyObject,
+    bound_args: PyObject,
     expressions: Vec<PyExpr>,
     return_dtype: PyDataType,
+    init_args: PyObject,
     resource_request: Option<ResourceRequest>,
-    init_args: Option<PyObject>,
     batch_size: Option<usize>,
     concurrency: Option<usize>,
 ) -> PyResult<PyExpr> {
-    use crate::functions::python::stateful_udf;
+    use crate::functions::python::udf;

     if let Some(batch_size) = batch_size {
         if batch_size == 0 {
@@ -237,15 +197,15 @@ pub fn stateful_udf(
     }

     let expressions_map: Vec<ExprRef> = expressions.into_iter().map(|pyexpr| pyexpr.expr).collect();
-    let init_args = init_args.map(|args| args.into());
     Ok(PyExpr {
-        expr: stateful_udf(
+        expr: udf(
             name,
-            partial_stateful_udf.into(),
+            inner.into(),
+            bound_args.into(),
             &expressions_map,
             return_dtype.dtype,
+            init_args.into(),
             resource_request,
-            init_args,
             batch_size,
             concurrency,
         )?
@@ -253,23 +213,18 @@ pub fn stateful_udf(
         .into(),
     })
 }

-/// Extracts the `class PartialStatefulUDF` Python objects that are in the specified expression tree
+/// Initializes all uninitialized UDFs in the expression
 #[pyfunction]
-pub fn extract_partial_stateful_udf_py(
-    expr: PyExpr,
-) -> HashMap<String, (Py<PyAny>, Option<Py<PyAny>>)> {
-    use crate::functions::python::extract_partial_stateful_udf_py;
-    extract_partial_stateful_udf_py(expr.expr)
+pub fn initialize_udfs(expr: PyExpr) -> PyResult<PyExpr> {
+    use crate::functions::python::initialize_udfs;
+    Ok(initialize_udfs(expr.expr)?.into())
 }

-/// Binds the StatefulPythonUDFs in a given expression to any corresponding initialized Python callables in the provided map
+/// Get the names of all UDFs in the expression
 #[pyfunction]
-pub fn bind_stateful_udfs(
-    expr: PyExpr,
-    initialized_funcs: HashMap<String, Py<PyAny>>,
-) -> PyResult<PyExpr> {
-    use crate::functions::python::bind_stateful_udfs;
-    Ok(bind_stateful_udfs(expr.expr, &initialized_funcs).map(PyExpr::from)?)
+pub fn get_udf_names(expr: PyExpr) -> Vec<String> {
+    use crate::functions::python::get_udf_names;
+    get_udf_names(&expr.expr)
 }

 #[pyclass(module = "daft.daft")]
diff --git a/src/daft-dsl/src/resolve_expr/mod.rs b/src/daft-dsl/src/resolve_expr/mod.rs
index 96bbf90513..35d97bc9a8 100644
--- a/src/daft-dsl/src/resolve_expr/mod.rs
+++ b/src/daft-dsl/src/resolve_expr/mod.rs
@@ -13,7 +13,7 @@ use daft_core::prelude::*;
 use typed_builder::TypedBuilder;

 use crate::{
-    col, expr::has_agg, functions::FunctionExpr, has_stateful_udf, AggExpr, Expr, ExprRef,
+    col, expr::has_agg, functions::FunctionExpr, is_actor_pool_udf, AggExpr, Expr, ExprRef,
 };

 // Calculates all the possible struct get expressions in a schema.
@@ -223,12 +223,12 @@ fn convert_udfs_to_map_groups(expr: &ExprRef) -> ExprRef {
 }

 /// Used for resolving and validating expressions.
-/// Specifically, makes sure the expression does not contain aggregations or stateful UDFs
+/// Specifically, makes sure the expression does not contain aggregations or actor pool UDFs
 /// where they are not allowed, and resolves struct accessors and wildcards.
 #[derive(Default, TypedBuilder)]
 pub struct ExprResolver<'a> {
     #[builder(default)]
-    allow_stateful_udf: bool,
+    allow_actor_pool_udf: bool,
     #[builder(via_mutators, mutators(
         pub fn in_agg_context(&mut self, in_agg_context: bool) {
             // workaround since typed_builder can't have defaults for mutator requirements
@@ -247,9 +247,9 @@ pub struct ExprResolver<'a> {

 impl<'a> ExprResolver<'a> {
     fn resolve_helper(&self, expr: ExprRef, schema: &Schema) -> DaftResult<Vec<ExprRef>> {
-        if !self.allow_stateful_udf && has_stateful_udf(&expr) {
+        if !self.allow_actor_pool_udf && expr.exists(is_actor_pool_udf) {
             return Err(DaftError::ValueError(format!(
-                "Stateful UDFs are only allowed in projections: {expr}"
+                "UDFs with concurrency set are only allowed in projections: {expr}"
             )));
         }
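
Note the reworded error: what used to be a "stateful UDF" check is now keyed purely on concurrency. A standalone sketch of the same gate (`reject_actor_pool_udfs` is a hypothetical helper mirroring `resolve_helper` above, for contexts such as filters that never allow these UDFs):

use common_error::{DaftError, DaftResult};
use daft_dsl::{is_actor_pool_udf, ExprRef};

// Hypothetical: non-projection contexts must not contain actor pool UDFs.
fn reject_actor_pool_udfs(expr: &ExprRef) -> DaftResult<()> {
    if expr.exists(is_actor_pool_udf) {
        return Err(DaftError::ValueError(format!(
            "UDFs with concurrency set are only allowed in projections: {expr}"
        )));
    }
    Ok(())
}
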
diff --git a/src/daft-local-execution/src/intermediate_ops/actor_pool_project.rs b/src/daft-local-execution/src/intermediate_ops/actor_pool_project.rs
index 810c6c7560..92ba1de5f7 100644
--- a/src/daft-local-execution/src/intermediate_ops/actor_pool_project.rs
+++ b/src/daft-local-execution/src/intermediate_ops/actor_pool_project.rs
@@ -4,7 +4,11 @@ use common_error::DaftResult;
 use common_runtime::RuntimeRef;
 #[cfg(feature = "python")]
 use daft_dsl::python::PyExpr;
-use daft_dsl::{functions::python::extract_stateful_udf_exprs, ExprRef};
+use daft_dsl::{
+    count_actor_pool_udfs,
+    functions::python::{get_batch_size, get_concurrency},
+    ExprRef,
+};
 #[cfg(feature = "python")]
 use daft_micropartition::python::PyMicroPartition;
 use daft_micropartition::MicroPartition;
@@ -33,8 +37,8 @@ impl ActorHandle {
         let handle = Python::with_gil(|py| {
             // create python object
             Ok::<PyObject, PyErr>(
-                py.import_bound(pyo3::intern!(py, "daft.execution.stateful_actor"))?
-                    .getattr(pyo3::intern!(py, "StatefulActorHandle"))?
+                py.import_bound(pyo3::intern!(py, "daft.execution.actor_pool_udf"))?
+                    .getattr(pyo3::intern!(py, "ActorHandle"))?
                     .call1((projection
                         .iter()
                         .map(|expr| PyExpr::from(expr.clone()))
@@ -70,7 +74,7 @@ impl ActorHandle {

         #[cfg(not(feature = "python"))]
         {
-            panic!("Cannot evaluate a stateful UDF without compiling for Python");
+            panic!("Cannot evaluate a UDF without compiling for Python");
         }
     }

@@ -98,7 +102,7 @@ impl Drop for ActorHandle {
         let result = self.teardown();

         if let Err(e) = result {
-            log::error!("Error tearing down stateful UDF actor: {}", e);
+            log::error!("Error tearing down UDF actor: {}", e);
         }
     }
 }
@@ -126,21 +130,20 @@ pub struct ActorPoolProjectOperator {

 impl ActorPoolProjectOperator {
     pub fn new(projection: Vec<ExprRef>) -> Self {
-        let stateful_udf_vec = projection
-            .iter()
-            .flat_map(|expr| extract_stateful_udf_exprs(expr.clone()))
-            .collect::<Vec<_>>();
+        let num_actor_pool_udfs: usize = count_actor_pool_udfs(&projection);
+
+        assert_eq!(
+            num_actor_pool_udfs, 1,
+            "Expected only one actor pool udf in an actor pool project"
+        );

-        let [stateful_udf] = stateful_udf_vec
-            .try_into()
-            .expect("Expected only one stateful udf in an actor pool project");
+        let concurrency = get_concurrency(&projection);
+        let batch_size = get_batch_size(&projection);

         Self {
             projection,
-            concurrency: stateful_udf
-                .concurrency
-                .expect("Stateful UDF should have concurrency"),
-            batch_size: stateful_udf.batch_size,
+            concurrency,
+            batch_size,
         }
     }
 }
diff --git a/src/daft-logical-plan/src/builder.rs b/src/daft-logical-plan/src/builder.rs
index 0daaf79460..8e7936a1a8 100644
--- a/src/daft-logical-plan/src/builder.rs
+++ b/src/daft-logical-plan/src/builder.rs
@@ -26,7 +26,7 @@ use {
 use crate::{
     logical_plan::LogicalPlan,
     ops,
-    optimization::{Optimizer, OptimizerConfig},
+    optimization::Optimizer,
     partitioning::{
         HashRepartitionConfig, IntoPartitionsConfig, RandomShuffleConfig, RepartitionSpec,
     },
@@ -590,16 +590,7 @@ impl LogicalPlanBuilder {
     }

     pub fn optimize(&self) -> DaftResult<Self> {
-        let default_optimizer_config: OptimizerConfig = Default::default();
-        let optimizer_config = OptimizerConfig {
-            enable_actor_pool_projections: self
-                .config
-                .as_ref()
-                .map(|planning_cfg| planning_cfg.enable_actor_pool_projections)
-                .unwrap_or(default_optimizer_config.enable_actor_pool_projections),
-            ..default_optimizer_config
-        };
-        let optimizer = Optimizer::new(optimizer_config);
+        let optimizer = Optimizer::new(Default::default());

         // Run LogicalPlan optimizations
         let unoptimized_plan = self.build();
diff --git a/src/daft-logical-plan/src/ops/actor_pool_project.rs b/src/daft-logical-plan/src/ops/actor_pool_project.rs
index 78ec2a681f..d9a2aa0c4b 100644
--- a/src/daft-logical-plan/src/ops/actor_pool_project.rs
+++ b/src/daft-logical-plan/src/ops/actor_pool_project.rs
@@ -2,13 +2,10 @@ use std::sync::Arc;

 use common_error::DaftError;
 use common_resource_request::ResourceRequest;
-use common_treenode::TreeNode;
 use daft_dsl::{
-    functions::{
-        python::{get_concurrency, get_resource_request, PythonUDF, StatefulPythonUDF},
-        FunctionExpr,
-    },
-    Expr, ExprRef, ExprResolver,
+    count_actor_pool_udfs,
+    functions::python::{get_concurrency, get_resource_request, get_udf_names},
+    ExprRef, ExprResolver,
 };
 use daft_schema::schema::{Schema, SchemaRef};
 use itertools::Itertools;
@@ -31,33 +28,14 @@ pub struct ActorPoolProject {

 impl ActorPoolProject {
     pub(crate) fn try_new(input: Arc<LogicalPlan>, projection: Vec<ExprRef>) -> Result<Self> {
-        let expr_resolver = ExprResolver::builder().allow_stateful_udf(true).build();
+        let expr_resolver = ExprResolver::builder().allow_actor_pool_udf(true).build();

         let (projection, fields) = expr_resolver
             .resolve(projection, input.schema().as_ref())
             .context(CreationSnafu)?;

-        let num_stateful_udf_exprs: usize = projection
-            .iter()
-            .map(|expr| {
-                let mut num_stateful_udfs = 0;
-                expr.apply(|e| {
-                    if matches!(
-                        e.as_ref(),
-                        Expr::Function {
-                            func: FunctionExpr::Python(PythonUDF::Stateful(_)),
-                            ..
-                        }
-                    ) {
-                        num_stateful_udfs += 1;
-                    }
-                    Ok(common_treenode::TreeNodeRecursion::Continue)
-                })
-                .unwrap();
-                num_stateful_udfs
-            })
-            .sum();
-        if !num_stateful_udf_exprs == 1 {
-            return Err(Error::CreationError { source: DaftError::InternalError(format!("Expected ActorPoolProject to have exactly 1 stateful UDF expression but found: {num_stateful_udf_exprs}")) });
+        let num_actor_pool_udfs: usize = count_actor_pool_udfs(&projection);
+        if num_actor_pool_udfs != 1 {
+            return Err(Error::CreationError { source: DaftError::InternalError(format!("Expected ActorPoolProject to have exactly 1 actor pool UDF expression but found: {num_actor_pool_udfs}")) });
         }

         let projected_schema = Schema::new(fields).context(CreationSnafu)?.into();
@@ -94,28 +72,7 @@ impl ActorPoolProject {
         ));
         res.push(format!(
             "UDFs = [{}]",
-            self.projection
-                .iter()
-                .flat_map(|proj| {
-                    let mut udf_names = vec![];
-                    proj.apply(|e| {
-                        if let Expr::Function {
-                            func:
-                                FunctionExpr::Python(PythonUDF::Stateful(StatefulPythonUDF {
-                                    name,
-                                    ..
-                                })),
-                            ..
-                        } = e.as_ref()
-                        {
-                            udf_names.push(name.clone());
-                        }
-                        Ok(common_treenode::TreeNodeRecursion::Continue)
-                    })
-                    .unwrap();
-                    udf_names
-                })
-                .join(", ")
+            self.projection.iter().flat_map(get_udf_names).join(", ")
         ));
         res.push(format!("Concurrency = {}", self.concurrency()));
         if let Some(resource_request) = self.resource_request() {
diff --git a/src/daft-logical-plan/src/ops/project.rs b/src/daft-logical-plan/src/ops/project.rs
index c8c417f6e1..165d989a09 100644
--- a/src/daft-logical-plan/src/ops/project.rs
+++ b/src/daft-logical-plan/src/ops/project.rs
@@ -24,7 +24,7 @@ pub struct Project {

 impl Project {
     pub(crate) fn try_new(input: Arc<LogicalPlan>, projection: Vec<ExprRef>) -> Result<Self> {
-        let expr_resolver = ExprResolver::builder().allow_stateful_udf(true).build();
+        let expr_resolver = ExprResolver::builder().allow_actor_pool_udf(true).build();

         let (projection, fields) = expr_resolver
             .resolve(projection, &input.schema())
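
One correctness note on `ActorPoolProject::try_new` above: `!` on a `usize` is bitwise NOT, so the inherited spelling `!num_stateful_udf_exprs == 1` only held when the count was `usize::MAX`; the check is therefore written as `num_actor_pool_udfs != 1` in the rewrite. The invariant, sketched as a standalone function (hypothetical name, same types as the diff):

use common_error::{DaftError, DaftResult};
use daft_dsl::{count_actor_pool_udfs, ExprRef};

// Hypothetical: an ActorPoolProject projection must contain exactly one actor pool UDF.
fn validate_actor_pool_projection(projection: &[ExprRef]) -> DaftResult<()> {
    let n = count_actor_pool_udfs(projection);
    if n != 1 {
        // `n != 1`, not `!n == 1`: bitwise NOT on usize would make this branch unreachable.
        return Err(DaftError::InternalError(format!(
            "Expected ActorPoolProject to have exactly 1 actor pool UDF expression but found: {n}"
        )));
    }
    Ok(())
}
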
diff --git a/src/daft-logical-plan/src/optimization/optimizer.rs b/src/daft-logical-plan/src/optimization/optimizer.rs
index 61a3ff314e..7c68274d95 100644
--- a/src/daft-logical-plan/src/optimization/optimizer.rs
+++ b/src/daft-logical-plan/src/optimization/optimizer.rs
@@ -17,15 +17,12 @@ use crate::LogicalPlan;
 pub struct OptimizerConfig {
     // Default maximum number of optimization passes the optimizer will make over a fixed-point RuleBatch.
     pub default_max_optimizer_passes: usize,
-    // Feature flag for enabling creating ActorPoolProject nodes during plan optimization
-    pub enable_actor_pool_projections: bool,
 }

 impl OptimizerConfig {
-    fn new(max_optimizer_passes: usize, enable_actor_pool_projections: bool) -> Self {
+    fn new(max_optimizer_passes: usize) -> Self {
         Self {
             default_max_optimizer_passes: max_optimizer_passes,
-            enable_actor_pool_projections,
         }
     }
 }

@@ -33,7 +30,7 @@ impl OptimizerConfig {
 impl Default for OptimizerConfig {
     fn default() -> Self {
         // Default to a max of 5 optimizer passes for a given batch.
-        Self::new(5, false)
+        Self::new(5)
     }
 }

@@ -91,62 +88,48 @@ pub struct Optimizer {
 impl Optimizer {
     pub fn new(config: OptimizerConfig) -> Self {
-        let mut rule_batches = Vec::new();
-
-        // --- Split ActorPoolProjection nodes from Project nodes ---
-        // This is feature-flagged behind DAFT_ENABLE_ACTOR_POOL_PROJECTIONS=1
-        if config.enable_actor_pool_projections {
-            rule_batches.push(RuleBatch::new(
+        let rule_batches = vec![
+            // --- Rewrite rules ---
+            RuleBatch::new(
                 vec![
-                    Box::new(PushDownProjection::new()),
                     Box::new(SplitActorPoolProjects::new()),
+                    Box::new(LiftProjectFromAgg::new()),
+                ],
+                RuleExecutionStrategy::Once,
+            ),
+            // --- Bulk of our rules ---
+            RuleBatch::new(
+                vec![
+                    Box::new(DropRepartition::new()),
+                    Box::new(PushDownFilter::new()),
                     Box::new(PushDownProjection::new()),
+                    Box::new(EliminateCrossJoin::new()),
                 ],
+                // Use a fixed-point policy for the pushdown rules: PushDownProjection can produce a Filter node
+                // at the current node, which would require another batch application in order to have a chance to push
+                // that Filter node through upstream nodes.
+                // TODO(Clark): Refine this fixed-point policy.
+                RuleExecutionStrategy::FixedPoint(Some(3)),
+            ),
+            // --- Limit pushdowns ---
+            // This needs to be separate from PushDownProjection because otherwise the limit and
+            // projection just keep swapping places, preventing optimization
+            // (see https://github.com/Eventual-Inc/Daft/issues/2616)
+            RuleBatch::new(
+                vec![Box::new(PushDownLimit::new())],
+                RuleExecutionStrategy::FixedPoint(Some(3)),
+            ),
+            // --- Materialize scan nodes ---
+            RuleBatch::new(
+                vec![Box::new(MaterializeScans::new())],
                 RuleExecutionStrategy::Once,
-            ));
-        }
-
-        // --- Rewrite rules ---
-        rule_batches.push(RuleBatch::new(
-            vec![Box::new(LiftProjectFromAgg::new())],
-            RuleExecutionStrategy::Once,
-        ));
-
-        // --- Bulk of our rules ---
-        rule_batches.push(RuleBatch::new(
-            vec![
-                Box::new(DropRepartition::new()),
-                Box::new(PushDownFilter::new()),
-                Box::new(PushDownProjection::new()),
-                Box::new(EliminateCrossJoin::new()),
-            ],
-            // Use a fixed-point policy for the pushdown rules: PushDownProjection can produce a Filter node
-            // at the current node, which would require another batch application in order to have a chance to push
-            // that Filter node through upstream nodes.
-            // TODO(Clark): Refine this fixed-point policy.
-            RuleExecutionStrategy::FixedPoint(Some(3)),
-        ));
-
-        // --- Limit pushdowns ---
-        // This needs to be separate from PushDownProjection because otherwise the limit and
-        // projection just keep swapping places, preventing optimization
-        // (see https://github.com/Eventual-Inc/Daft/issues/2616)
-        rule_batches.push(RuleBatch::new(
-            vec![Box::new(PushDownLimit::new())],
-            RuleExecutionStrategy::FixedPoint(Some(3)),
-        ));
-
-        // --- Materialize scan nodes ---
-        rule_batches.push(RuleBatch::new(
-            vec![Box::new(MaterializeScans::new())],
-            RuleExecutionStrategy::Once,
-        ));
-
-        // --- Enrich logical plan with stats ---
-        rule_batches.push(RuleBatch::new(
-            vec![Box::new(EnrichWithStats::new())],
-            RuleExecutionStrategy::Once,
-        ));
+            ),
+            // --- Enrich logical plan with stats ---
+            RuleBatch::new(
+                vec![Box::new(EnrichWithStats::new())],
+                RuleExecutionStrategy::Once,
+            ),
+        ];

         Self::with_rule_batches(rule_batches, config)
     }
@@ -268,7 +251,7 @@ mod tests {
                 vec![Box::new(NoOp::new())],
                 RuleExecutionStrategy::Once,
             )],
-            OptimizerConfig::new(5, false),
+            OptimizerConfig::new(5),
         );
         let plan: Arc<LogicalPlan> =
             dummy_scan_node(dummy_scan_operator(vec![Field::new("a", DataType::Int64)])).build();
@@ -315,7 +298,7 @@ mod tests {
                 vec![Box::new(RotateProjection::new(false))],
                 RuleExecutionStrategy::FixedPoint(Some(20)),
             )],
-            OptimizerConfig::new(20, false),
+            OptimizerConfig::new(20),
         );
         let proj_exprs = vec![
             col("a").add(lit(1)),
@@ -350,7 +333,7 @@ mod tests {
                 vec![Box::new(RotateProjection::new(true))],
                 RuleExecutionStrategy::FixedPoint(Some(20)),
             )],
-            OptimizerConfig::new(20, false),
+            OptimizerConfig::new(20),
         );
         let proj_exprs = vec![
             col("a").add(lit(1)),
@@ -401,7 +384,7 @@ mod tests {
                     RuleExecutionStrategy::Once,
                 ),
             ],
-            OptimizerConfig::new(20, false),
+            OptimizerConfig::new(20),
         );
         let proj_exprs = vec![
             col("a").add(lit(1)),
diff --git a/src/daft-logical-plan/src/optimization/rules/push_down_projection.rs b/src/daft-logical-plan/src/optimization/rules/push_down_projection.rs
index 7c2391ccce..632f5e3bfe 100644
--- a/src/daft-logical-plan/src/optimization/rules/push_down_projection.rs
+++ b/src/daft-logical-plan/src/optimization/rules/push_down_projection.rs
@@ -4,7 +4,7 @@ use common_error::DaftResult;
 use common_treenode::{DynTreeNode, Transformed, TreeNode};
 use daft_core::prelude::*;
 use daft_dsl::{
-    col, has_stateful_udf,
+    col, is_actor_pool_udf,
     optimization::{get_required_columns, replace_columns_with_expressions, requires_computation},
     Expr, ExprRef,
 };
@@ -274,8 +274,11 @@ impl PushDownProjection {
                 })
                 .collect_vec();

-            // Construct either a new ActorPoolProject or Project, depending on whether the pruned projection still has StatefulUDFs
-            let new_plan = if new_actor_pool_projections.iter().any(has_stateful_udf) {
+            // Construct either a new ActorPoolProject or Project, depending on whether the pruned projection still has actor pool UDFs
+            let new_plan = if new_actor_pool_projections
+                .iter()
+                .any(|e| e.exists(is_actor_pool_udf))
+            {
                 LogicalPlan::ActorPoolProject(ActorPoolProject::try_new(
                     upstream_actor_pool_projection.input.clone(),
                     new_actor_pool_projections,
@@ -307,7 +310,7 @@ impl PushDownProjection {
                 .cloned()
                 .collect::<Vec<_>>();

-            // If all StatefulUDF expressions end up being pruned, the ActorPoolProject should essentially become
+            // If all actor pool UDF expressions end up being pruned, the ActorPoolProject should essentially become
             // a no-op passthrough projection for the rest of the columns. In this case, we should just get rid of it
             // altogether since it serves no purpose.
             let all_projections_are_just_colexprs =
@@ -664,12 +667,16 @@ mod tests {
     use std::sync::Arc;

     use common_error::DaftResult;
+    use common_resource_request::ResourceRequest;
     use common_scan_info::Pushdowns;
     use daft_core::prelude::*;
     use daft_dsl::{
         col,
-        functions::python::{RuntimePyObject, UDFRuntimeBinding},
-        lit,
+        functions::{
+            python::{MaybeInitializedUDF, PythonUDF, RuntimePyObject},
+            FunctionExpr,
+        },
+        lit, Expr, ExprRef,
     };

     use crate::{
@@ -692,6 +699,26 @@ mod tests {
         )
     }

+    fn create_actor_pool_udf(inputs: Vec<ExprRef>) -> ExprRef {
+        Expr::Function {
+            func: FunctionExpr::Python(PythonUDF {
+                name: Arc::new("my-udf".to_string()),
+                func: MaybeInitializedUDF::Uninitialized {
+                    inner: RuntimePyObject::new_testing_none(),
+                    init_args: RuntimePyObject::new_testing_none(),
+                },
+                bound_args: RuntimePyObject::new_testing_none(),
+                num_expressions: inputs.len(),
+                return_dtype: DataType::Utf8,
+                resource_request: Some(ResourceRequest::default_cpu()),
+                batch_size: None,
+                concurrency: Some(8),
+            }),
+            inputs,
+        }
+        .arced()
+    }
+
     /// Projection merging: Ensure factored projections do not get merged.
     #[test]
     fn test_merge_does_not_unfactor() -> DaftResult<()> {
@@ -903,15 +930,6 @@ mod tests {
     /// Projection<-ActorPoolProject prunes columns from the ActorPoolProject
     #[test]
     fn test_projection_pushdown_into_actorpoolproject() -> DaftResult<()> {
-        use common_resource_request::ResourceRequest;
-        use daft_dsl::{
-            functions::{
-                python::{PythonUDF, StatefulPythonUDF},
-                FunctionExpr,
-            },
-            Expr,
-        };
-
         use crate::ops::{ActorPoolProject, Project};

         let scan_op = dummy_scan_operator(vec![
             Field::new("a", DataType::Int64),
             Field::new("b", DataType::Int64),
             Field::new("c", DataType::Int64),
         ]);
         let scan_node = dummy_scan_node(scan_op.clone());
-        let mock_stateful_udf = Expr::Function {
-            func: FunctionExpr::Python(PythonUDF::Stateful(StatefulPythonUDF {
-                name: Arc::new("my-udf".to_string()),
-                stateful_partial_func: RuntimePyObject::new_testing_none(),
-                num_expressions: 1,
-                return_dtype: DataType::Utf8,
-                resource_request: Some(ResourceRequest::default_cpu()),
-                batch_size: None,
-                concurrency: Some(8),
-                init_args: None,
-                runtime_binding: UDFRuntimeBinding::Unbound,
-            })),
-            inputs: vec![col("c")],
-        }
-        .arced();
+        let mock_udf = create_actor_pool_udf(vec![col("c")]);

         // Select the `udf_results` column, so the ActorPoolProject should apply column pruning to the other columns
         let actor_pool_project = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new(
             scan_node.build(),
-            vec![col("a"), col("b"), mock_stateful_udf.alias("udf_results")],
+            vec![col("a"), col("b"), mock_udf.alias("udf_results")],
         )?)
         .arced();
         let project = LogicalPlan::Project(Project::try_new(
@@ -954,7 +958,7 @@ mod tests {
                 Pushdowns::default().with_columns(Some(Arc::new(vec!["c".to_string()]))),
             )
             .build(),
-            vec![mock_stateful_udf.alias("udf_results")],
+            vec![mock_udf.alias("udf_results")],
         )?)
         .arced();

@@ -965,15 +969,6 @@ mod tests {
     /// Projection<-ActorPoolProject<-ActorPoolProject prunes columns from both ActorPoolProjects
     #[test]
     fn test_projection_pushdown_into_double_actorpoolproject() -> DaftResult<()> {
-        use common_resource_request::ResourceRequest;
-        use daft_dsl::{
-            functions::{
-                python::{PythonUDF, StatefulPythonUDF},
-                FunctionExpr,
-            },
-            Expr,
-        };
-
         use crate::ops::{ActorPoolProject, Project};

         let scan_op = dummy_scan_operator(vec![
             Field::new("a", DataType::Int64),
             Field::new("b", DataType::Int64),
             Field::new("c", DataType::Int64),
         ]);
         let scan_node = dummy_scan_node(scan_op.clone()).build();
-        let mock_stateful_udf = Expr::Function {
-            func: FunctionExpr::Python(PythonUDF::Stateful(StatefulPythonUDF {
-                name: Arc::new("my-udf".to_string()),
-                stateful_partial_func: RuntimePyObject::new_testing_none(),
-                num_expressions: 1,
-                return_dtype: DataType::Utf8,
-                resource_request: Some(ResourceRequest::default_cpu()),
-                batch_size: None,
-                concurrency: Some(8),
-                init_args: None,
-                runtime_binding: UDFRuntimeBinding::Unbound,
-            })),
-            inputs: vec![col("a")],
-        }
-        .arced();
+        let mock_udf = create_actor_pool_udf(vec![col("a")]);

         // Select the `udf_results` column, so the ActorPoolProject should apply column pruning to the other columns
         let plan = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new(
             scan_node,
-            vec![col("a"), col("b"), mock_stateful_udf.alias("udf_results_0")],
+            vec![col("a"), col("b"), mock_udf.alias("udf_results_0")],
         )?)
         .arced();

@@ -1011,7 +992,7 @@ mod tests {
                 col("a"),
                 col("b"),
                 col("udf_results_0"),
-                mock_stateful_udf.alias("udf_results_1"),
+                mock_udf.alias("udf_results_1"),
             ],
         )?)
         .arced();
@@ -1032,7 +1013,7 @@ mod tests {
             )
             .build(),
             // col("b") is pruned
-            vec![mock_stateful_udf.alias("udf_results_0"), col("a")],
+            vec![mock_udf.alias("udf_results_0"), col("a")],
         )?)
         .arced();
         let expected = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new(
@@ -1040,7 +1021,7 @@ mod tests {
             vec![
                 // Absorbed a non-computational expression (alias) from the Projection
                 col("udf_results_0").alias("udf_results_0_alias"),
-                mock_stateful_udf.alias("udf_results_1"),
+                mock_udf.alias("udf_results_1"),
             ],
         )?)
         .arced();

@@ -1049,18 +1030,9 @@ mod tests {
         Ok(())
     }

-    /// Projection<-ActorPoolProject prunes ActorPoolProject entirely if the stateful projection column is pruned
+    /// Projection<-ActorPoolProject prunes ActorPoolProject entirely if the actor pool UDF column is pruned
     #[test]
     fn test_projection_pushdown_into_actorpoolproject_completely_removed() -> DaftResult<()> {
-        use common_resource_request::ResourceRequest;
-        use daft_dsl::{
-            functions::{
-                python::{PythonUDF, StatefulPythonUDF},
-                FunctionExpr,
-            },
-            Expr,
-        };
-
         use crate::ops::{ActorPoolProject, Project};

         let scan_op = dummy_scan_operator(vec![
             Field::new("a", DataType::Int64),
             Field::new("b", DataType::Int64),
             Field::new("c", DataType::Int64),
         ]);
         let scan_node = dummy_scan_node(scan_op.clone()).build();
-        let mock_stateful_udf = Expr::Function {
-            func: FunctionExpr::Python(PythonUDF::Stateful(StatefulPythonUDF {
-                name: Arc::new("my-udf".to_string()),
-                stateful_partial_func: RuntimePyObject::new_testing_none(),
-                num_expressions: 1,
-                return_dtype: DataType::Utf8,
-                resource_request: Some(ResourceRequest::default_cpu()),
-                batch_size: None,
-                concurrency: Some(8),
-                init_args: None,
-                runtime_binding: UDFRuntimeBinding::Unbound,
-            })),
-            inputs: vec![col("c")],
-        }
-        .arced();
+        let mock_udf = create_actor_pool_udf(vec![col("c")]);

         // Select only col("a"), so the ActorPoolProject node is now redundant and should be removed
         let actor_pool_project = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new(
             scan_node,
-            vec![col("a"), col("b"), mock_stateful_udf.alias("udf_results")],
+            vec![col("a"), col("b"), mock_udf.alias("udf_results")],
         )?)
         .arced();
         let project =
diff --git a/src/daft-logical-plan/src/optimization/rules/split_actor_pool_projects.rs b/src/daft-logical-plan/src/optimization/rules/split_actor_pool_projects.rs
index cfd1895488..d2b717ef20 100644
--- a/src/daft-logical-plan/src/optimization/rules/split_actor_pool_projects.rs
+++ b/src/daft-logical-plan/src/optimization/rules/split_actor_pool_projects.rs
@@ -3,10 +3,7 @@ use std::{collections::HashSet, iter, sync::Arc};
 use common_error::DaftResult;
 use common_treenode::{Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter};
 use daft_dsl::{
-    functions::{
-        python::{PythonUDF, StatefulPythonUDF},
-        FunctionExpr,
-    },
+    is_actor_pool_udf,
     optimization::{get_required_columns, requires_computation},
     Expr, ExprRef,
 };
@@ -29,8 +26,8 @@ impl SplitActorPoolProjects {
 /// Implement SplitActorPoolProjects as an OptimizerRule
 /// * Splits PROJECT nodes into chains of (PROJECT -> ...ACTOR_POOL_PROJECTS -> PROJECT)
-/// * Resultant PROJECT nodes will never contain any StatefulUDF expressions
-/// * Each ACTOR_POOL_PROJECT node only contains a single StatefulUDF expression
+/// * Resultant PROJECT nodes will never contain any actor pool UDF expressions
+/// * Each ACTOR_POOL_PROJECT node only contains a single actor pool UDF expression
 ///
 /// Given a projection with 3 expressions that look like the following:
 ///
@@ -39,22 +36,22 @@ impl SplitActorPoolProjects {
 /// │ ┌─────┐ ┌─────┐ ┌─────┐ │
 /// │ │ E1  │ │ E2  │ │ E3  │ │
 /// │ │     │ │     │ │     │ │
-/// │ StatefulUDF│ Stateless│ Stateless│ │
+/// │ UDF │ Stateless│ Stateless│ │
 /// │ └──┬──┘ └─┬┬──┘ └──┬──┘ │
 /// │    │   ┌──┘└──┐    │ │
 /// │ ┌──▼──┐ ┌───▼─┐ ┌─▼───┐ ┌──▼────┐ │
 /// │ │ E1a │ │ E2a │ │ E2b │ │col(E3)│ │
 /// │ │     │ │     │ │     │ └───────┘ │
-/// │ Any │ StatefulUDF│ │ Stateless │
+/// │ Any │ UDF │ │ Stateless │
 /// │ └─────┘ └─────┘ └─────┘ │
 /// │ │
 /// └───────────────────────────────────────────────────────────────┘
 ///
 /// We will attempt to split this recursively into "stages". We split a given projection by truncating each expression as follows:
 ///
-/// 1. (See E1 -> E1') Expressions with (aliased) StatefulUDFs as root nodes have all their children truncated
-/// 2. (See E2 -> E2') Expressions with children StatefulUDFs have each child StatefulUDF truncated
-/// 3. (See E3) Expressions without any StatefulUDFs at all are not modified
+/// 1. (See E1 -> E1') Expressions with (aliased) actor pool UDFs as root nodes have all their children truncated
+/// 2. (See E2 -> E2') Expressions with children actor pool UDFs have each child actor pool UDF truncated
+/// 3. (See E3) Expressions without any actor pool UDFs at all are not modified
 ///
 /// The truncated children as well as any required `col` references are collected into a new set of [`remaining`]
 /// expressions. The new [`truncated_exprs`] make up current stage, and the [`remaining`] exprs represent the projections
 ///
 /// ┌───────────────────────────────────────────────────────────SPLIT: split_projection()
 /// │ │
-/// │ TruncateRootStatefulUDF TruncateAnyStatefulUDFChildren No-Op │
+/// │ TruncateRootActorPoolUDF TruncateAnyActorPoolUDFChildren No-Op │
 /// │ ======================= ============================== ===== │
 /// │ ┌─────┐ ┌─────┐ ┌─────┐ │
 /// │ │ E1' │ │ E2' │ │ E3  │ │
-/// │ StatefulUDF│ Stateless│ Stateless│ │
+/// │ UDF │ Stateless│ Stateless│ │
 /// │ └───┬─┘ └─┬┬──┘ └──┬──┘ │
 /// │     │  ┌───┘└───┐  │ │
 /// │ *--- ▼---* *--- ▼---* ┌──▼──┐ ┌──▼────┐ │
@@ -84,7 +81,7 @@ impl SplitActorPoolProjects {
 /// │ ┌──▼──┐ ┌───▼─┐ │
 /// │ │ E1a │ │ E2a │ │
 /// │ │     │ │     │ │
-/// │ Any │ StatefulUDF│ │
+/// │ Any │ UDF │ │
 /// │ └─────┘ └─────┘ │
 /// │ │
 /// └─────────────────────────────────────────────────────────────────────────────────┘
 ///
@@ -99,14 +96,14 @@ impl SplitActorPoolProjects {
 /// │ ┌──▼──┐ ┌───▼─┐ │
 /// │ │ E1a │ │ E2a │ │
 /// │ │     │ │     │ │
-/// │ Any │ StatefulUDF│ │
+/// │ Any │ UDF │ │
 /// │ └─────┘ └─────┘ │
 /// │ │
 /// └─┬─────────────────────────────────────────────────────────────┘
 ///   |
 ///   │ Then, we link this up with our current stage, which will be resolved into a chain of logical nodes:
 ///   | * The first PROJECT contains all the stateless expressions (E2' and E3) and passes through all required columns.
-///   | * Subsequent ACTOR_POOL_PROJECT nodes each contain only one StatefulUDF, and passes through all required columns.
+///   | * Subsequent ACTOR_POOL_PROJECT nodes each contain only one actor pool UDF, and passes through all required columns.
 ///   | * The last PROJECT contains only `col` references, and correctly orders/prunes columns according to the original projection.
 ///   |
 ///   │
@@ -115,7 +112,7 @@ impl SplitActorPoolProjects {
 /// │ │ PROJECT         │ │ ACTOR_POOL_PROJECT │                │ PROJECT   │
 /// │ │ -------         │ │ ------------------ │ ...ACTOR_PPs,  │ ----------│
 /// └───►│ E2', E3, col(*) ├─►│ E1', col(*)    ├─ 1 per each ──►│ col("e1") │
-/// │ │                 │ │ StatefulUDF        │                │ col("e2") │
+/// │ │                 │ │ actor pool UDF     │                │ col("e2") │
 /// │ │                 │ │                    │                │ col("e3") │
 /// │ │                 │ │                    │                │           │
 /// └─────────────────┘ └────────────────────┘                  └───────────┘
@@ -128,15 +125,15 @@ impl OptimizerRule for SplitActorPoolProjects {
     }
 }

-// TreeNodeRewriter that assumes the Expression tree is rooted at a StatefulUDF (or alias of a StatefulUDF)
+// TreeNodeRewriter that assumes the Expression tree is rooted at an actor pool UDF (or alias of an actor pool UDF)
 // and its children need to be truncated + replaced with Expr::Columns
-struct TruncateRootStatefulUDF {
+struct TruncateRootActorPoolUDF {
     pub(crate) new_children: Vec<ExprRef>,
     stage_idx: usize,
     expr_idx: usize,
 }

-impl TruncateRootStatefulUDF {
+impl TruncateRootActorPoolUDF {
     fn new(stage_idx: usize, expr_idx: usize) -> Self {
         Self {
             new_children: Vec::new(),
@@ -146,15 +143,15 @@
     }
 }

-// TreeNodeRewriter that assumes the Expression tree has some children which are StatefulUDFs
+// TreeNodeRewriter that assumes the Expression tree has some children which are actor pool UDFs
 // which needs to be truncated and replaced with Expr::Columns
-struct TruncateAnyStatefulUDFChildren {
+struct TruncateAnyActorPoolUDFChildren {
     pub(crate) new_children: Vec<ExprRef>,
     stage_idx: usize,
     expr_idx: usize,
 }

-impl TruncateAnyStatefulUDFChildren {
+impl TruncateAnyActorPoolUDFChildren {
     fn new(stage_idx: usize, expr_idx: usize) -> Self {
         Self {
             new_children: Vec::new(),
@@ -164,14 +161,14 @@
     }
 }

-/// Performs truncation of Expressions which are assumed to be rooted at a StatefulUDF expression
+/// Performs truncation of Expressions which are assumed to be rooted at an actor pool UDF expression
 ///
-/// This TreeNodeRewriter will truncate all children of the StatefulUDF expression like so:
+/// This TreeNodeRewriter will truncate all children of the actor pool UDF expression like so:
 ///
 /// 1. Add an `alias(...)` to the child and push it onto `self.new_children`
 /// 2. Replace the child with a `col("...")`
 /// 3. Add any `col("...")` leaf nodes to `self.new_children` (only once per unique column name)
-impl TreeNodeRewriter for TruncateRootStatefulUDF {
+impl TreeNodeRewriter for TruncateRootActorPoolUDF {
     type Node = ExprRef;

     fn f_down(&mut self, node: Self::Node) -> DaftResult<Transformed<Self::Node>> {
@@ -188,17 +185,15 @@ impl TreeNodeRewriter for TruncateRootStatefulUDF {
                 }
                 Ok(common_treenode::Transformed::no(node))
             }
-            // Encountered stateful UDF: chop off all children and add to self.next_children
-            Expr::Function {
-                func: FunctionExpr::Python(PythonUDF::Stateful(StatefulPythonUDF { .. })),
-                inputs,
-            } => {
+            // Encountered actor pool UDF: chop off all children and add to self.new_children
+            _ if is_actor_pool_udf(&node) => {
                 let mut monotonically_increasing_expr_identifier = 0;
+                let inputs = node.children();
                 let new_inputs = inputs.iter().map(|e| {
                     if requires_computation(e.as_ref()) {
                         // Give the new child a deterministic name
                         let intermediate_expr_name = format!(
-                            "__TruncateRootStatefulUDF_{}-{}-{}__",
+                            "__TruncateRootActorPoolUDF_{}-{}-{}__",
                             self.stage_idx, self.expr_idx, monotonically_increasing_expr_identifier
                         );
                         monotonically_increasing_expr_identifier += 1;
@@ -218,25 +213,22 @@
     }
 }

-/// Performs truncation of Expressions which are assumed to have some subtrees which contain StatefulUDF expressions
+/// Performs truncation of Expressions which are assumed to have some subtrees which contain actor pool UDF expressions
 ///
-/// This TreeNodeRewriter will truncate StatefulUDF expressions from the tree like so:
+/// This TreeNodeRewriter will truncate actor pool UDF expressions from the tree like so:
 ///
-/// 1. Add an `alias(...)` to any StatefulUDF child and push it onto `self.new_children`
+/// 1. Add an `alias(...)` to any actor pool UDF child and push it onto `self.new_children`
 /// 2. Replace the child with a `col("...")`
 /// 3. Add any `col("...")` leaf nodes to `self.new_children` (only once per unique column name)
-impl TreeNodeRewriter for TruncateAnyStatefulUDFChildren {
+impl TreeNodeRewriter for TruncateAnyActorPoolUDFChildren {
     type Node = ExprRef;

     fn f_down(&mut self, node: Self::Node) -> DaftResult<Transformed<Self::Node>> {
         match node.as_ref() {
-            // This rewriter should never encounter a StatefulUDF expression (they should always be truncated and replaced)
-            Expr::Function {
-                func: FunctionExpr::Python(PythonUDF::Stateful(StatefulPythonUDF { .. })),
-                ..
-            } => {
+            // This rewriter should never encounter an actor pool UDF expression (they should always be truncated and replaced)
+            _ if is_actor_pool_udf(&node) => {
                 unreachable!(
-                    "TruncateAnyStatefulUDFChildren should never run on a StatefulUDF expression"
+                    "TruncateAnyActorPoolUDFChildren should never run on an actor pool UDF expression"
                 );
             }
             // If we encounter a ColumnExpr, we add it to new_children only if it hasn't already been accounted for
@@ -251,37 +243,19 @@ impl TreeNodeRewriter for TruncateAnyActorPoolUDFChildren {
                 }
                 Ok(common_treenode::Transformed::no(node))
             }
-            // Attempt to truncate any children that are StatefulUDFs, replacing them with a Expr::Column
+            // Attempt to truncate any children that are actor pool UDFs, replacing them with an Expr::Column
             expr => {
-                // None of the direct children are stateful UDFs, so we keep going
-                if node.children().iter().all(|e| {
-                    !matches!(
-                        e.as_ref(),
-                        Expr::Function {
-                            func: FunctionExpr::Python(PythonUDF::Stateful(
-                                StatefulPythonUDF { .. }
-                            )),
-                            ..
-                        }
-                    )
-                }) {
+                // None of the direct children are actor pool UDFs, so we keep going
+                if !node.children().iter().any(is_actor_pool_udf) {
                     return Ok(common_treenode::Transformed::no(node));
                 }

                 let mut monotonically_increasing_expr_identifier = 0;
                 let inputs = expr.children();
                 let new_inputs = inputs.iter().map(|e| {
-                    if matches!(
-                        e.as_ref(),
-                        Expr::Function {
-                            func: FunctionExpr::Python(PythonUDF::Stateful(
-                                StatefulPythonUDF { .. }
-                            )),
-                            ..
-                        }
-                    ) {
+                    if is_actor_pool_udf(e) {
                         let intermediate_expr_name = format!(
-                            "__TruncateAnyStatefulUDFChildren_{}-{}-{}__",
+                            "__TruncateAnyActorPoolUDFChildren_{}-{}-{}__",
                             self.stage_idx, self.expr_idx, monotonically_increasing_expr_identifier
                         );
                         monotonically_increasing_expr_identifier += 1;
@@ -309,27 +283,24 @@ fn split_projection(
     let (mut new_children_seen, mut new_children): (HashSet<String>, Vec<ExprRef>) =
         (HashSet::new(), Vec::new());

-    fn _is_stateful_udf_and_should_truncate_children(expr: &ExprRef) -> bool {
-        let mut is_stateful_udf = true;
+    fn _is_actor_pool_udf_and_should_truncate_children(expr: &ExprRef) -> bool {
+        let mut cond = true;
         expr.apply(|e| match e.as_ref() {
             Expr::Alias(..) => Ok(TreeNodeRecursion::Continue),
-            Expr::Function {
-                func: FunctionExpr::Python(PythonUDF::Stateful(StatefulPythonUDF { .. })),
-                ..
-            } => Ok(TreeNodeRecursion::Stop),
+            _ if is_actor_pool_udf(e) => Ok(TreeNodeRecursion::Stop),
             _ => {
-                is_stateful_udf = false;
+                cond = false;
                 Ok(TreeNodeRecursion::Stop)
             }
         })
         .unwrap();
-        is_stateful_udf
+        cond
     }

     for (expr_idx, expr) in projection.iter().enumerate() {
-        // Run the TruncateRootStatefulUDF TreeNodeRewriter
-        if _is_stateful_udf_and_should_truncate_children(expr) {
-            let mut rewriter = TruncateRootStatefulUDF::new(stage_idx, expr_idx);
+        // Run the TruncateRootActorPoolUDF TreeNodeRewriter
+        if _is_actor_pool_udf_and_should_truncate_children(expr) {
+            let mut rewriter = TruncateRootActorPoolUDF::new(stage_idx, expr_idx);
             let rewritten_root = expr.clone().rewrite(&mut rewriter)?.data;
             truncated_exprs.push(rewritten_root);
             for new_child in rewriter.new_children {
@@ -339,9 +310,9 @@ fn split_projection(
             }
         }

-        // Run the TruncateAnyStatefulUDFChildren TreeNodeRewriter
-        } else if has_stateful_udf(expr) {
-            let mut rewriter = TruncateAnyStatefulUDFChildren::new(stage_idx, expr_idx);
+        // Run the TruncateAnyActorPoolUDFChildren TreeNodeRewriter
+        } else if expr.exists(is_actor_pool_udf) {
+            let mut rewriter = TruncateAnyActorPoolUDFChildren::new(stage_idx, expr_idx);
             let rewritten_root = expr.clone().rewrite(&mut rewriter)?.data;
             truncated_exprs.push(rewritten_root);
             for new_child in rewriter.new_children {
@@ -375,17 +346,17 @@ fn try_optimize_project(
     projection: &Project,
     plan: Arc<LogicalPlan>,
 ) -> DaftResult<Transformed<Arc<LogicalPlan>>> {
-    // Add aliases to the expressions in the projection to preserve original names when splitting stateful UDFs.
-    // This is needed because when we split stateful UDFs, we create new names for intermediates, but we would like
+    // Add aliases to the expressions in the projection to preserve original names when splitting actor pool UDFs.
+    // This is needed because when we split actor pool UDFs, we create new names for intermediates, but we would like
     // to have the same expression names as the original projection.
     let aliased_projection_exprs = projection
         .projection
         .iter()
-        .map(|e| {
-            if has_stateful_udf(e) && !matches!(e.as_ref(), Expr::Alias(..)) {
-                e.alias(e.name())
+        .map(|expr| {
+            if expr.exists(is_actor_pool_udf) && !matches!(expr.as_ref(), Expr::Alias(..)) {
+                expr.alias(expr.name())
             } else {
-                e.clone()
+                expr.clone()
             }
         })
         .collect();
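
To make the staging concrete before the recursion below: for one nested actor pool UDF, `foo(foo(col("a"))).alias("b")`, the expected plans in the tests further down pin the output to the following chain (sketched as an outline rather than runnable code, with the intermediate name taken from `TruncateRootActorPoolUDF`):

// Input:  Project([col("a"), foo(foo(col("a"))).alias("b")])
// Output after SplitActorPoolProjects:
//   Project([col("a")])
//     -> ActorPoolProject([col("a"), foo(col("a")).alias("__TruncateRootActorPoolUDF_0-1-0__")])
//     -> ActorPoolProject([col("__TruncateRootActorPoolUDF_0-1-0__"), col("a"),
//                          foo(col("__TruncateRootActorPoolUDF_0-1-0__")).alias("b")])
//     -> Project([col("a"), col("b")])
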
@@ -402,9 +373,12 @@ fn recursive_optimize_project(
 ) -> DaftResult<Transformed<Arc<LogicalPlan>>> {
     // TODO: eliminate the need for recursive calls by doing a post-order traversal of the plan tree.

-    // Base case: no stateful UDFs at all
-    let has_stateful_udfs = projection.projection.iter().any(has_stateful_udf);
-    if !has_stateful_udfs {
+    // Base case: no actor pool UDFs at all
+    let has_actor_pool_udfs = projection
+        .projection
+        .iter()
+        .any(|expr| expr.exists(is_actor_pool_udf));
+    if !has_actor_pool_udfs {
         return Ok(Transformed::no(plan));
     }

@@ -452,8 +426,9 @@ fn recursive_optimize_project(
     };

     // Start building a chain of `child -> Project -> ActorPoolProject -> ActorPoolProject -> ... -> Project`
-    let (stateful_stages, stateless_stages): (Vec<_>, Vec<_>) =
-        truncated_exprs.into_iter().partition(has_stateful_udf);
+    let (actor_pool_stages, stateless_stages): (Vec<_>, Vec<_>) = truncated_exprs
+        .into_iter()
+        .partition(|expr| expr.exists(is_actor_pool_udf));

     // Build the new stateless Project: [...all columns that came before it, ...stateless_projections]
     let passthrough_columns = {
@@ -481,30 +456,27 @@ fn recursive_optimize_project(
     let new_plan =
         LogicalPlan::Project(Project::try_new(new_plan_child, stateless_projection)?).arced();

-    // Iteratively build ActorPoolProject nodes: [...all columns that came before it, StatefulUDF]
+    // Iteratively build ActorPoolProject nodes: [...all columns that came before it, actor pool UDF]
     let new_plan = {
         let mut child = new_plan;

-        for stateful_expr in stateful_stages {
-            let stateful_expr_name = stateful_expr.name().to_string();
-            let stateful_projection = child
+        for expr in actor_pool_stages {
+            let expr_name = expr.name().to_string();
+            let projection = child
                 .schema()
                 .fields
                 .iter()
                 .filter_map(|(name, _)| {
-                    if name == &stateful_expr_name {
+                    if name == &expr_name {
                         None
                     } else {
                         Some(Expr::Column(name.as_str().into()).arced())
                     }
                 })
-                .chain(iter::once(stateful_expr))
+                .chain(iter::once(expr))
                 .collect();

-            child = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new(
-                child,
-                stateful_projection,
-            )?)
-            .arced();
+            child = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new(child, projection)?)
+                .arced();
         }
         child
     };
@@ -524,19 +496,6 @@ fn recursive_optimize_project(
     Ok(Transformed::yes(final_selection_project))
 }

-#[inline]
-fn has_stateful_udf(e: &ExprRef) -> bool {
-    e.exists(|e| {
-        matches!(
-            e.as_ref(),
-            Expr::Function {
-                func: FunctionExpr::Python(PythonUDF::Stateful(_)),
-                ..
-            }
-        )
-    })
-}
-
 #[cfg(test)]
 mod tests {
     use std::sync::Arc;
@@ -547,7 +506,7 @@ mod tests {
     use daft_dsl::{
         col,
         functions::{
-            python::{PythonUDF, StatefulPythonUDF, UDFRuntimeBinding},
+            python::{MaybeInitializedUDF, PythonUDF, RuntimePyObject},
             FunctionExpr,
         },
         Expr, ExprRef,
     };
@@ -562,7 +521,7 @@ mod tests {
         LogicalPlan,
     };

-    /// Helper that creates an optimizer with the SplitExprByStatefulUDF rule registered, optimizes
+    /// Helper that creates an optimizer with the SplitExprByActorPoolUDF rule registered, optimizes
     /// the provided plan with said optimizer, and compares the optimized plan with
     /// the provided expected plan.
     fn assert_optimized_plan_eq(
@@ -576,7 +535,7 @@ mod tests {
         )
     }

-    /// Helper that creates an optimizer with the SplitExprByStatefulUDF rule registered, optimizes
+    /// Helper that creates an optimizer with the SplitExprByActorPoolUDF rule registered, optimizes
     /// the provided plan with said optimizer, and compares the optimized plan with
     /// the provided expected plan.
     fn assert_optimized_plan_eq_with_projection_pushdown(
@@ -593,20 +552,21 @@ mod tests {
         )
     }

-    fn create_stateful_udf(inputs: Vec<ExprRef>) -> ExprRef {
+    fn create_actor_pool_udf(inputs: Vec<ExprRef>) -> ExprRef {
         Expr::Function {
-            func: FunctionExpr::Python(PythonUDF::Stateful(StatefulPythonUDF {
+            func: FunctionExpr::Python(PythonUDF {
                 name: Arc::new("foo".to_string()),
-                stateful_partial_func:
-                    daft_dsl::functions::python::RuntimePyObject::new_testing_none(),
+                func: MaybeInitializedUDF::Uninitialized {
+                    inner: RuntimePyObject::new_testing_none(),
+                    init_args: RuntimePyObject::new_testing_none(),
+                },
+                bound_args: RuntimePyObject::new_testing_none(),
                 num_expressions: inputs.len(),
-                return_dtype: DataType::Int64,
+                return_dtype: DataType::Utf8,
                 resource_request: Some(create_resource_request()),
                 batch_size: None,
                 concurrency: Some(8),
-                init_args: None,
-                runtime_binding: UDFRuntimeBinding::Unbound,
-            })),
+            }),
             inputs,
         }
         .arced()
@@ -617,21 +577,21 @@ mod tests {
     }

     #[test]
-    fn test_with_column_stateful_udf_happypath() -> DaftResult<()> {
+    fn test_with_column_actor_pool_udf_happypath() -> DaftResult<()> {
         let scan_op = dummy_scan_operator(vec![Field::new("a", DataType::Utf8)]);
         let scan_plan = dummy_scan_node(scan_op);
-        let stateful_project_expr = create_stateful_udf(vec![col("a")]);
+        let actor_pool_project_expr = create_actor_pool_udf(vec![col("a")]);

-        // Add a Projection with StatefulUDF and resource request
+        // Add a Projection with actor pool UDF and resource request
         let project_plan = scan_plan
-            .with_columns(vec![stateful_project_expr.alias("b")])?
+            .with_columns(vec![actor_pool_project_expr.alias("b")])?
             .build();

         // Project([col("a")]) --> ActorPoolProject([col("a"), foo(col("a")).alias("b")]) --> Project([col("a"), col("b")])
         let expected = scan_plan.select(vec![col("a")])?.build();
         let expected = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new(
             expected,
-            vec![col("a"), stateful_project_expr.alias("b")],
+            vec![col("a"), actor_pool_project_expr.alias("b")],
         )?)
         .arced();
         let expected =
@@ -651,20 +611,20 @@ mod tests {
         let scan_plan = dummy_scan_node(scan_op);
         let project_plan = scan_plan
             .with_columns(vec![
-                create_stateful_udf(vec![create_stateful_udf(vec![col("a")])]).alias("a_prime"),
-                create_stateful_udf(vec![create_stateful_udf(vec![col("b")])]).alias("b_prime"),
+                create_actor_pool_udf(vec![create_actor_pool_udf(vec![col("a")])]).alias("a_prime"),
+                create_actor_pool_udf(vec![create_actor_pool_udf(vec![col("b")])]).alias("b_prime"),
             ])?
             .build();

-        let intermediate_column_name_0 = "__TruncateRootStatefulUDF_0-2-0__";
-        let intermediate_column_name_1 = "__TruncateRootStatefulUDF_0-3-0__";
+        let intermediate_column_name_0 = "__TruncateRootActorPoolUDF_0-2-0__";
+        let intermediate_column_name_1 = "__TruncateRootActorPoolUDF_0-3-0__";
         let expected = scan_plan.select(vec![col("a"), col("b")])?.build();
         let expected = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new(
             expected,
             vec![
                 col("a"),
                 col("b"),
-                create_stateful_udf(vec![col("a")]).alias(intermediate_column_name_0),
+                create_actor_pool_udf(vec![col("a")]).alias(intermediate_column_name_0),
             ],
         )?)
         .arced();
@@ -674,7 +634,7 @@ mod tests {
                 col("a"),
                 col("b"),
                 col(intermediate_column_name_0),
-                create_stateful_udf(vec![col("b")]).alias(intermediate_column_name_1),
+                create_actor_pool_udf(vec![col("b")]).alias(intermediate_column_name_1),
             ],
         )?)
         .arced();
@@ -705,7 +665,7 @@ mod tests {
                 col(intermediate_column_name_1),
                 col("a"),
                 col("b"),
-                create_stateful_udf(vec![col(intermediate_column_name_0)]).alias("a_prime"),
+                create_actor_pool_udf(vec![col(intermediate_column_name_0)]).alias("a_prime"),
             ],
         )?)
         .arced();
@@ -717,7 +677,7 @@ mod tests {
                 col("a"),
                 col("b"),
                 col("a_prime"),
-                create_stateful_udf(vec![col(intermediate_column_name_1)]).alias("b_prime"),
+                create_actor_pool_udf(vec![col(intermediate_column_name_1)]).alias("b_prime"),
             ],
         )?)
         .arced();
@@ -734,22 +694,22 @@ mod tests {
     fn test_multiple_with_column_serial() -> DaftResult<()> {
         let scan_op = dummy_scan_operator(vec![Field::new("a", DataType::Utf8)]);
         let scan_plan = dummy_scan_node(scan_op);
-        let stacked_stateful_project_expr =
-            create_stateful_udf(vec![create_stateful_udf(vec![col("a")])]);
+        let stacked_actor_pool_project_expr =
+            create_actor_pool_udf(vec![create_actor_pool_udf(vec![col("a")])]);

-        // Add a Projection with StatefulUDF and resource request
+        // Add a Projection with actor pool UDF and resource request
         // Project([col("a"), foo(foo(col("a"))).alias("b")])
         let project_plan = scan_plan
-            .with_columns(vec![stacked_stateful_project_expr.alias("b")])?
+            .with_columns(vec![stacked_actor_pool_project_expr.alias("b")])?
             .build();

-        let intermediate_name = "__TruncateRootStatefulUDF_0-1-0__";
+        let intermediate_name = "__TruncateRootActorPoolUDF_0-1-0__";
         let expected = scan_plan.select(vec![col("a")])?.build();
         let expected = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new(
             expected,
             vec![
                 col("a"),
-                create_stateful_udf(vec![col("a")]).alias(intermediate_name),
+                create_actor_pool_udf(vec![col("a")]).alias(intermediate_name),
             ],
         )?)
         .arced();
@@ -768,7 +728,7 @@ mod tests {
             vec![
                 col(intermediate_name),
                 col("a"),
-                create_stateful_udf(vec![col(intermediate_name)]).alias("b"),
+                create_actor_pool_udf(vec![col(intermediate_name)]).alias("b"),
             ],
         )?)
         .arced();
@@ -781,7 +741,7 @@ mod tests {
         let expected = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new(
             expected,
             vec![
-                create_stateful_udf(vec![col("a")]).alias(intermediate_name),
+                create_actor_pool_udf(vec![col("a")]).alias(intermediate_name),
                 col("a"),
             ],
         )?)
         .arced();
@@ -790,7 +750,7 @@ mod tests {
             expected,
             vec![
                 col("a"),
-                create_stateful_udf(vec![col(intermediate_name)]).alias("b"),
+                create_actor_pool_udf(vec![col(intermediate_name)]).alias("b"),
             ],
         )?)
         .arced();
@@ -802,22 +762,22 @@ mod tests {
     fn test_multiple_with_column_serial_no_alias() -> DaftResult<()> {
         let scan_op = dummy_scan_operator(vec![Field::new("a", DataType::Utf8)]);
         let scan_plan = dummy_scan_node(scan_op);
-        let stacked_stateful_project_expr =
-            create_stateful_udf(vec![create_stateful_udf(vec![col("a")])]);
+        let stacked_actor_pool_project_expr =
+            create_actor_pool_udf(vec![create_actor_pool_udf(vec![col("a")])]);

-        // Add a Projection with StatefulUDF and resource request
+        // Add a Projection with actor pool UDF and resource request
         let project_plan = scan_plan
-            .select(vec![stacked_stateful_project_expr])?
+            .select(vec![stacked_actor_pool_project_expr])?
             .build();

-        let intermediate_name = "__TruncateRootStatefulUDF_0-0-0__";
+        let intermediate_name = "__TruncateRootActorPoolUDF_0-0-0__";
         let expected = scan_plan.select(vec![col("a")])?.build();
         let expected = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new(
             expected,
             vec![
                 col("a"),
-                create_stateful_udf(vec![col("a")]).alias(intermediate_name),
+                create_actor_pool_udf(vec![col("a")]).alias(intermediate_name),
             ],
         )?)
         .arced();
@@ -829,7 +789,7 @@ mod tests {
             expected,
             vec![
                 col(intermediate_name),
-                create_stateful_udf(vec![col(intermediate_name)]).alias("a"),
+                create_actor_pool_udf(vec![col(intermediate_name)]).alias("a"),
             ],
         )?)
         .arced();
@@ -838,12 +798,12 @@ mod tests {

         let expected = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new(
             scan_plan.build(),
-            vec![create_stateful_udf(vec![col("a")]).alias(intermediate_name)],
+            vec![create_actor_pool_udf(vec![col("a")]).alias(intermediate_name)],
         )?)
         .arced();
         let expected = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new(
             expected,
-            vec![create_stateful_udf(vec![col(intermediate_name)]).alias("a")],
+            vec![create_actor_pool_udf(vec![col(intermediate_name)]).alias("a")],
         )?)
         .arced();
         assert_optimized_plan_eq_with_projection_pushdown(project_plan, expected)?;
@@ -858,26 +818,26 @@ mod tests {
             Field::new("b", DataType::Utf8),
         ]);
         let scan_plan = dummy_scan_node(scan_op);
-        let stacked_stateful_project_expr = create_stateful_udf(vec![
-            create_stateful_udf(vec![col("a")]),
-            create_stateful_udf(vec![col("b")]),
+        let stacked_actor_pool_project_expr = create_actor_pool_udf(vec![
+            create_actor_pool_udf(vec![col("a")]),
+            create_actor_pool_udf(vec![col("b")]),
         ]);

-        // Add a Projection with StatefulUDF and resource request
+        // Add a Projection with actor pool UDF and resource request
         // Project([foo(foo(col("a")), foo(col("b"))).alias("c")])
         let project_plan = scan_plan
-            .select(vec![stacked_stateful_project_expr.alias("c")])?
+            .select(vec![stacked_actor_pool_project_expr.alias("c")])?
             .build();

-        let intermediate_name_0 = "__TruncateRootStatefulUDF_0-0-0__";
-        let intermediate_name_1 = "__TruncateRootStatefulUDF_0-0-1__";
+        let intermediate_name_0 = "__TruncateRootActorPoolUDF_0-0-0__";
+        let intermediate_name_1 = "__TruncateRootActorPoolUDF_0-0-1__";
         let expected = scan_plan.select(vec![col("a"), col("b")])?.build();
         let expected = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new(
             expected,
             vec![
                 col("a"),
                 col("b"),
-                create_stateful_udf(vec![col("a")]).alias(intermediate_name_0),
+                create_actor_pool_udf(vec![col("a")]).alias(intermediate_name_0),
             ],
         )?)
         .arced();
@@ -887,7 +847,7 @@ mod tests {
                 col("a"),
                 col("b"),
                 col(intermediate_name_0),
-                create_stateful_udf(vec![col("b")]).alias(intermediate_name_1),
+                create_actor_pool_udf(vec![col("b")]).alias(intermediate_name_1),
             ],
         )?)
         .arced();
@@ -906,7 +866,7 @@ mod tests {
             vec![
                 col(intermediate_name_0),
                 col(intermediate_name_1),
-                create_stateful_udf(vec![col(intermediate_name_0), col(intermediate_name_1)])
+                create_actor_pool_udf(vec![col(intermediate_name_0), col(intermediate_name_1)])
                     .alias("c"),
             ],
         )?)
         .arced();
@@ -919,7 +879,7 @@ mod tests {
         let expected = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new(
             expected,
             vec![
-                create_stateful_udf(vec![col("a")]).alias(intermediate_name_0),
+                create_actor_pool_udf(vec![col("a")]).alias(intermediate_name_0),
                 col("b"),
             ],
         )?)
         .arced();
@@ -928,14 +888,14 @@ mod tests {
             expected,
             vec![
                 col(intermediate_name_0),
-                create_stateful_udf(vec![col("b")]).alias(intermediate_name_1),
+                create_actor_pool_udf(vec![col("b")]).alias(intermediate_name_1),
             ],
         )?)
         .arced();
         let expected = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new(
             expected,
             vec![
-                create_stateful_udf(vec![col(intermediate_name_0), col(intermediate_name_1)])
+                create_actor_pool_udf(vec![col(intermediate_name_0), col(intermediate_name_1)])
                     .alias("c"),
             ],
         )?)
@@ -951,27 +911,27 @@ mod tests {
             Field::new("b", DataType::Int64),
         ]);
         let scan_plan = dummy_scan_node(scan_op);
-        let stacked_stateful_project_expr = create_stateful_udf(vec![create_stateful_udf(vec![
-            col("a"),
-        ])
-        .add(create_stateful_udf(vec![col("b")]))]);
+        let stacked_actor_pool_project_expr =
+            create_actor_pool_udf(vec![
+                create_actor_pool_udf(vec![col("a")]).add(create_actor_pool_udf(vec![col("b")]))
+            ]);

-        // Add a Projection with StatefulUDF and resource request
+        // Add a Projection with actor pool UDF and resource request
         // Project([foo(foo(col("a")) + foo(col("b"))).alias("c")])
         let project_plan = scan_plan
-            .select(vec![stacked_stateful_project_expr.alias("c")])?
+            .select(vec![stacked_actor_pool_project_expr.alias("c")])?
             .build();

-        let intermediate_name_0 = "__TruncateAnyStatefulUDFChildren_1-0-0__";
-        let intermediate_name_1 = "__TruncateAnyStatefulUDFChildren_1-0-1__";
-        let intermediate_name_2 = "__TruncateRootStatefulUDF_0-0-0__";
+        let intermediate_name_0 = "__TruncateAnyActorPoolUDFChildren_1-0-0__";
+        let intermediate_name_1 = "__TruncateAnyActorPoolUDFChildren_1-0-1__";
+        let intermediate_name_2 = "__TruncateRootActorPoolUDF_0-0-0__";
         let expected = scan_plan.select(vec![col("a"), col("b")])?.build();
         let expected = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new(
             expected,
             vec![
                 col("a"),
                 col("b"),
-                create_stateful_udf(vec![col("a")]).alias(intermediate_name_0),
+                create_actor_pool_udf(vec![col("a")]).alias(intermediate_name_0),
             ],
         )?)
         .arced();
@@ -981,7 +941,7 @@ mod tests {
                 col("a"),
                 col("b"),
                 col(intermediate_name_0),
-                create_stateful_udf(vec![col("b")]).alias(intermediate_name_1),
+                create_actor_pool_udf(vec![col("b")]).alias(intermediate_name_1),
             ],
         )?)
         .arced();
@@ -1011,7 +971,7 @@ mod tests {
             expected,
             vec![
                 col(intermediate_name_2),
-                create_stateful_udf(vec![col(intermediate_name_2)]).alias("c"),
+                create_actor_pool_udf(vec![col(intermediate_name_2)]).alias("c"),
             ],
         )?)
         .arced();
@@ -1023,7 +983,7 @@ mod tests {
         let expected = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new(
             expected,
             vec![
-                create_stateful_udf(vec![col("a")]).alias(intermediate_name_0),
+                create_actor_pool_udf(vec![col("a")]).alias(intermediate_name_0),
                 col("b"),
             ],
         )?)
@@ -1032,7 +992,7 @@ mod tests {
             expected,
             vec![
                 col(intermediate_name_0),
-                create_stateful_udf(vec![col("b")]).alias(intermediate_name_1),
+                create_actor_pool_udf(vec![col("b")]).alias(intermediate_name_1),
             ],
         )?)
         .arced();
@@ -1045,7 +1005,7 @@ mod tests {
         .arced();
         let expected = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new(
             expected,
-            vec![create_stateful_udf(vec![col(intermediate_name_2)]).alias("c")],
+            vec![create_actor_pool_udf(vec![col(intermediate_name_2)]).alias("c")],
         )?)
         .arced();
         assert_optimized_plan_eq_with_projection_pushdown(project_plan, expected)?;
@@ -1056,24 +1016,24 @@ mod tests {
     fn test_nested_with_column_same_names() -> DaftResult<()> {
         let scan_op = dummy_scan_operator(vec![Field::new("a", DataType::Int64)]);
         let scan_plan = dummy_scan_node(scan_op);
-        let stacked_stateful_project_expr =
-            create_stateful_udf(vec![col("a").add(create_stateful_udf(vec![col("a")]))]);
+        let stacked_actor_pool_project_expr =
+            create_actor_pool_udf(vec![col("a").add(create_actor_pool_udf(vec![col("a")]))]);

-        // Add a Projection with StatefulUDF and resource request
+        // Add a Projection with actor pool UDF and resource request
         // Project([foo(col("a") + foo(col("a"))).alias("c")])
         let project_plan = scan_plan
-            .select(vec![col("a"), stacked_stateful_project_expr.alias("c")])?
+            .select(vec![col("a"), stacked_actor_pool_project_expr.alias("c")])?
             .build();

-        let intermediate_name_0 = "__TruncateAnyStatefulUDFChildren_1-1-0__";
-        let intermediate_name_1 = "__TruncateRootStatefulUDF_0-1-0__";
+        let intermediate_name_0 = "__TruncateAnyActorPoolUDFChildren_1-1-0__";
+        let intermediate_name_1 = "__TruncateRootActorPoolUDF_0-1-0__";
         let expected = scan_plan.build();
         let expected = LogicalPlan::Project(Project::try_new(expected, vec![col("a")])?).arced();
         let expected = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new(
             expected,
             vec![
                 col("a"),
-                create_stateful_udf(vec![col("a")]).alias(intermediate_name_0),
+                create_actor_pool_udf(vec![col("a")]).alias(intermediate_name_0),
             ],
         )?)
         .arced();
@@ -1108,7 +1068,7 @@ mod tests {
             vec![
                 col(intermediate_name_1),
                 col("a"),
-                create_stateful_udf(vec![col(intermediate_name_1)]).alias("c"),
+                create_actor_pool_udf(vec![col(intermediate_name_1)]).alias("c"),
             ],
         )?)
         .arced();
@@ -1121,28 +1081,28 @@ mod tests {
     }

     #[test]
-    fn test_stateless_expr_with_only_some_stateful_children() -> DaftResult<()> {
+    fn test_stateless_expr_with_only_some_actor_pool_children() -> DaftResult<()> {
         let scan_op = dummy_scan_operator(vec![Field::new("a", DataType::Int64)]);
         let scan_plan = dummy_scan_node(scan_op);
         // (col("a") + col("a")) + foo(col("a"))
-        let stateful_project_expr = col("a")
+        let actor_pool_project_expr = col("a")
             .add(col("a"))
-            .add(create_stateful_udf(vec![col("a")]))
+            .add(create_actor_pool_udf(vec![col("a")]))
             .alias("result");
         let project_plan = scan_plan
-            .select(vec![col("a"), stateful_project_expr])?
+            .select(vec![col("a"), actor_pool_project_expr])?
             .build();

-        let intermediate_name_0 = "__TruncateAnyStatefulUDFChildren_0-1-0__";
-        // let intermediate_name_1 = "__TruncateRootStatefulUDF_0-1-0__";
+        let intermediate_name_0 = "__TruncateAnyActorPoolUDFChildren_0-1-0__";
+        // let intermediate_name_1 = "__TruncateRootActorPoolUDF_0-1-0__";
         let expected = scan_plan.build();
         let expected = LogicalPlan::Project(Project::try_new(expected, vec![col("a")])?).arced();
         let expected = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new(
             expected,
             vec![
                 col("a"),
-                create_stateful_udf(vec![col("a")]).alias(intermediate_name_0),
+                create_actor_pool_udf(vec![col("a")]).alias(intermediate_name_0),
             ],
         )?)
         .arced();
diff --git a/src/daft-physical-plan/src/ops/actor_pool_project.rs b/src/daft-physical-plan/src/ops/actor_pool_project.rs
index 514bd6a1c4..389c398b54 100644
--- a/src/daft-physical-plan/src/ops/actor_pool_project.rs
+++ b/src/daft-physical-plan/src/ops/actor_pool_project.rs
@@ -4,8 +4,9 @@ use common_error::{DaftError, DaftResult};
 use common_resource_request::ResourceRequest;
 use common_treenode::TreeNode;
 use daft_dsl::{
+    count_actor_pool_udfs,
     functions::{
-        python::{get_concurrency, get_resource_request, PythonUDF, StatefulPythonUDF},
+        python::{get_concurrency, get_resource_request, PythonUDF},
         FunctionExpr,
     },
     Expr, ExprRef,
@@ -27,28 +28,9 @@ impl ActorPoolProject {
     pub(crate) fn try_new(input: PhysicalPlanRef, projection: Vec<ExprRef>) -> DaftResult<Self> {
         let clustering_spec = translate_clustering_spec(input.clustering_spec(), &projection);

-        let num_stateful_udf_exprs: usize = projection
-            .iter()
-            .map(|expr| {
-                let mut num_stateful_udfs = 0;
-                expr.apply(|e| {
-                    if matches!(
-                        e.as_ref(),
-                        Expr::Function {
-                            func: FunctionExpr::Python(PythonUDF::Stateful(_)),
-                            ..
-                        }
-                    ) {
-                        num_stateful_udfs += 1;
-                    }
-                    Ok(common_treenode::TreeNodeRecursion::Continue)
-                })
-                .unwrap();
-                num_stateful_udfs
-            })
-            .sum();
-        if !num_stateful_udf_exprs == 1 {
-            return Err(DaftError::InternalError(format!("Expected ActorPoolProject to have exactly 1 stateful UDF expression but found: {num_stateful_udf_exprs}")));
+        let num_actor_pool_udfs: usize = count_actor_pool_udfs(&projection);
+        if num_actor_pool_udfs != 1 {
+            return Err(DaftError::InternalError(format!("Expected ActorPoolProject to have exactly 1 actor pool UDF expression but found: {num_actor_pool_udfs}")));
         }

         Ok(Self {
@@ -83,10 +65,11 @@ impl ActorPoolProject {
             proj.apply(|e| {
                 if let Expr::Function {
                     func:
-                        FunctionExpr::Python(PythonUDF::Stateful(StatefulPythonUDF {
+                        FunctionExpr::Python(PythonUDF {
                             name,
+                            concurrency: Some(_),
                             ..
-                        })),
+                        }),
                     ..
                 } = e.as_ref()
                 {
diff --git a/src/daft-table/src/ops/agg.rs b/src/daft-table/src/ops/agg.rs
index 93ef8425d7..df0150257b 100644
--- a/src/daft-table/src/ops/agg.rs
+++ b/src/daft-table/src/ops/agg.rs
@@ -73,10 +73,17 @@ impl Table {
                 use daft_dsl::functions::python::PythonUDF;

                 let udf = match func {
-                    FunctionExpr::Python(PythonUDF::Stateless(udf)) => udf,
-                    FunctionExpr::Python(PythonUDF::Stateful(_)) => {
+                    FunctionExpr::Python(
+                        udf @ PythonUDF {
+                            concurrency: None, ..
+                        },
+                    ) => udf,
+                    FunctionExpr::Python(PythonUDF {
+                        concurrency: Some(_),
+                        ..
+                    }) => {
                         return Err(DaftError::ComputeError(
-                            "Cannot run stateful UDF in MapGroups".to_string(),
+                            "Cannot run actor pool UDF in MapGroups".to_string(),
                         ))
                     }
                     _ => {
diff --git a/tests/actor_pool/test_actor_cuda_devices.py b/tests/actor_pool/test_actor_cuda_devices.py
index 60509b21c0..74af05599c 100644
--- a/tests/actor_pool/test_actor_cuda_devices.py
+++ b/tests/actor_pool/test_actor_cuda_devices.py
@@ -8,7 +8,6 @@
 import daft
 from daft import udf
-from daft.context import get_context, set_planning_config
 from daft.datatype import DataType
 from daft.internal.gpu import cuda_visible_devices
 from tests.conftest import get_tests_daft_runner_name
@@ -18,19 +17,6 @@
 )


-@pytest.fixture(scope="module")
-def enable_actor_pool():
-    try:
-        original_config = get_context().daft_planning_config
-
-        set_planning_config(
-            config=get_context().daft_planning_config.with_config_values(enable_actor_pool_projections=True)
-        )
-        yield
-    finally:
-        set_planning_config(config=original_config)
-
-
 @contextmanager
 def reset_runner_with_gpus(num_gpus, monkeypatch):
     """If current runner does not have enough GPUs, create a new runner with mocked GPU resources"""
@@ -59,7 +45,7 @@ def reset_runner_with_gpus(num_gpus, monkeypatch):

 @pytest.mark.parametrize("concurrency", [1, 2])
 @pytest.mark.parametrize("num_gpus", [1, 2])
-def test_stateful_udf_cuda_env_var(enable_actor_pool, monkeypatch, concurrency, num_gpus):
+def test_actor_pool_udf_cuda_env_var(monkeypatch, concurrency, num_gpus):
     with reset_runner_with_gpus(concurrency * num_gpus, monkeypatch):

         @udf(return_dtype=DataType.string(), num_gpus=num_gpus)
@@ -91,7 +77,7 @@ def __call__(self, data):
     assert len(all_devices) == concurrency * num_gpus


-def test_stateful_udf_fractional_gpu(enable_actor_pool, monkeypatch):
+def test_actor_pool_udf_fractional_gpu(monkeypatch):
     with reset_runner_with_gpus(1, monkeypatch):

         @udf(return_dtype=DataType.string(), num_gpus=0.5)
@@ -121,7 +107,7 @@ def __call__(self, data):

 @pytest.mark.skipif(get_tests_daft_runner_name() != "py", reason="Test can only be run on PyRunner")
-def test_stateful_udf_no_cuda_devices(enable_actor_pool, monkeypatch):
+def test_actor_pool_udf_no_cuda_devices(monkeypatch):
     monkeypatch.setattr(daft.internal.gpu, "_raw_device_count_nvml", lambda: 0)
     monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
@@ -150,7 +136,7 @@ def __call__(self, data):

 @pytest.mark.skipif(get_tests_daft_runner_name() != "py", reason="Test can only be run on PyRunner")
-def test_stateful_udf_no_cuda_visible_device_envvar(enable_actor_pool, monkeypatch):
+def test_actor_pool_udf_no_cuda_visible_device_envvar(monkeypatch):
     monkeypatch.setattr(daft.internal.gpu, "_raw_device_count_nvml", lambda: 1)
     monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
diff --git a/tests/actor_pool/test_pyactor_pool.py b/tests/actor_pool/test_pyactor_pool.py
index 9f15faf97d..477fcca22f 100644
--- a/tests/actor_pool/test_pyactor_pool.py
+++ b/tests/actor_pool/test_pyactor_pool.py
@@ -6,7 +6,7 @@
 import daft
 from daft import DataType, ResourceRequest
 from daft.context import get_context
-from daft.execution.execution_step import StatefulUDFProject
+from daft.execution.execution_step import ActorPoolProject
 from daft.expressions import ExpressionsProjection
 from daft.runners.partitioning import PartialPartitionMetadata
 from daft.runners.pyrunner import AcquiredResources, PyActorPool, PyRunner
@@ -15,7 +15,7 @@

 @daft.udf(return_dtype=DataType.int64())
-class MyStatefulUDF:
+class MyUDF:
     def __init__(self):
         self.state = 0
@@ -25,11 +25,11 @@ def __call__(self, x):

 def test_pyactor_pool():
-    projection = ExpressionsProjection([MyStatefulUDF(daft.col("x"))])
+    projection = ExpressionsProjection([MyUDF(daft.col("x"))])
     pool = PyActorPool("my-pool", 1, [AcquiredResources(num_cpus=1, gpus={}, memory_bytes=0)], projection)
     initial_partition = MicroPartition.from_pydict({"x": [1, 1, 1]})
     ppm = PartialPartitionMetadata(num_rows=None, size_bytes=None)
-    instr = StatefulUDFProject(projection=projection)
+    instr = ActorPoolProject(projection=projection)

     pool.setup()
@@ -66,7 +66,7 @@ def test_pyactor_pool_not_enough_resources():
     from copy import deepcopy

     cpu_count = multiprocessing.cpu_count()
-    projection = ExpressionsProjection([MyStatefulUDF(daft.col("x"))])
+    projection = ExpressionsProjection([MyUDF(daft.col("x"))])

     runner = get_context().get_or_create_runner()
     assert isinstance(runner, PyRunner)
diff --git a/tests/actor_pool/test_ray_actor_pool.py b/tests/actor_pool/test_ray_actor_pool.py
index 239a5fd48b..086dcda2a9 100644
--- a/tests/actor_pool/test_ray_actor_pool.py
+++ b/tests/actor_pool/test_ray_actor_pool.py
@@ -3,7 +3,6 @@
 import daft
 from daft import DataType, ResourceRequest
 from daft.daft import PyDaftExecutionConfig
-from daft.execution.execution_step import StatefulUDFProject
 from daft.expressions import ExpressionsProjection
 from daft.runners.partitioning import PartialPartitionMetadata
 from daft.runners.ray_runner import RayRoundRobinActorPool
@@ -11,7 +10,7 @@

 @daft.udf(return_dtype=DataType.int64())
-class MyStatefulUDF:
+class MyUDF:
     def __init__(self):
         self.state = 0
@@ -21,16 +20,15 @@ def __call__(self, x):

 def test_ray_actor_pool():
-    projection = ExpressionsProjection([MyStatefulUDF(daft.col("x"))])
+    projection = ExpressionsProjection([MyUDF(daft.col("x"))])
     pool = RayRoundRobinActorPool(
         "my-pool", 1, ResourceRequest(num_cpus=1), projection, execution_config=PyDaftExecutionConfig.from_env()
     )
     initial_partition = ray.put(MicroPartition.from_pydict({"x": [1, 1, 1]}))
     ppm = PartialPartitionMetadata(num_rows=None, size_bytes=None)
-    instr = StatefulUDFProject(projection=projection)
     pool.setup()
-    result = pool.submit(instruction_stack=[instr], partial_metadatas=[ppm], inputs=[initial_partition])
+    result = pool.submit(partial_metadatas=[ppm], inputs=[initial_partition])
     [partial_metadata, result_data] = ray.get(result)
     assert len(partial_metadata) == 1
     pm = partial_metadata[0]
@@ -38,7 +36,7 @@ def test_ray_actor_pool():
     assert pm.num_rows == 3
     assert result_data.to_pydict() == {"x": [2, 2, 2]}

-    result = pool.submit(instruction_stack=[instr], partial_metadatas=[ppm], inputs=[initial_partition])
+    result = pool.submit(partial_metadatas=[ppm], inputs=[initial_partition])
     [partial_metadata, result_data] = ray.get(result)
     assert len(partial_metadata) == 1
     pm = partial_metadata[0]
@@ -46,7 +44,7 @@ def test_ray_actor_pool():
     assert pm.num_rows == 3
     assert result_data.to_pydict() == {"x": [3, 3, 3]}

-    result = pool.submit(instruction_stack=[instr], partial_metadatas=[ppm], inputs=[initial_partition])
+    result = pool.submit(partial_metadatas=[ppm], inputs=[initial_partition])
     [partial_metadata, result_data] = ray.get(result)
     assert len(partial_metadata) == 1
     pm = partial_metadata[0]
diff --git a/tests/expressions/test_udf.py b/tests/expressions/test_udf.py
index 35763dba40..13b35d3bbf 100644
--- a/tests/expressions/test_udf.py
+++ b/tests/expressions/test_udf.py
@@ -6,7 +6,6 @@
 import daft
 from daft import col
-from daft.context import get_context, set_planning_config
 from daft.datatype import DataType
 from daft.expressions import Expression
 from daft.expressions.testing import expr_structurally_equal
@@ -15,18 +14,6 @@
 from daft.udf import udf


-@pytest.fixture(scope="function", params=[False, True])
-def actor_pool_enabled(request):
-    original_config = get_context().daft_planning_config
-    try:
-        set_planning_config(
-            config=get_context().daft_planning_config.with_config_values(enable_actor_pool_projections=request.param)
-        )
-        yield request.param
-    finally:
-        set_planning_config(config=original_config)
-
-
 def test_udf():
     table = MicroPartition.from_pydict({"a": ["foo", "bar", "baz"]})
@@ -44,7 +31,8 @@ def repeat_n(data, n):

 @pytest.mark.parametrize("batch_size", [None, 1, 2, 3, 10])
-def test_class_udf(batch_size, actor_pool_enabled):
+@pytest.mark.parametrize("use_actor_pool", [False, True])
+def test_class_udf(batch_size, use_actor_pool):
     df = daft.from_pydict({"a": ["foo", "bar", "baz"]})

     @udf(return_dtype=DataType.string(), batch_size=batch_size)
@@ -55,8 +43,8 @@ def __init__(self):
         def __call__(self, data):
             return Series.from_pylist([d.as_py() * self.n for d in data.to_arrow()])

-    if actor_pool_enabled:
-        RepeatN = RepeatN.with_concurrency(1)
+    if use_actor_pool:
+        RepeatN = RepeatN.with_concurrency(2)

     expr = RepeatN(col("a"))
     field = expr._to_field(df.schema())
@@ -68,7 +56,8 @@ def __call__(self, data):

 @pytest.mark.parametrize("batch_size", [None, 1, 2, 3, 10])
-def test_class_udf_init_args(batch_size, actor_pool_enabled):
+@pytest.mark.parametrize("use_actor_pool", [False, True])
+def test_class_udf_init_args(batch_size, use_actor_pool):
     df = daft.from_pydict({"a": ["foo", "bar", "baz"]})

     @udf(return_dtype=DataType.string(), batch_size=batch_size)
@@ -79,8 +68,8 @@ def __init__(self, initial_n: int = 2):
         def __call__(self, data):
             return Series.from_pylist([d.as_py() * self.n for d in data.to_arrow()])

-    if actor_pool_enabled:
-        RepeatN = RepeatN.with_concurrency(1)
+    if use_actor_pool:
+        RepeatN = RepeatN.with_concurrency(2)

     expr = RepeatN(col("a"))
     field = expr._to_field(df.schema())
@@ -98,7 +87,8 @@ def __call__(self, data):

 @pytest.mark.parametrize("batch_size", [None, 1, 2, 3, 10])
-def test_class_udf_init_args_no_default(batch_size, actor_pool_enabled):
+@pytest.mark.parametrize("use_actor_pool", [False, True])
+def test_class_udf_init_args_no_default(batch_size, use_actor_pool):
     df = daft.from_pydict({"a": ["foo", "bar", "baz"]})

     @udf(return_dtype=DataType.string(), batch_size=batch_size)
@@ -109,10 +99,10 @@ def __init__(self, initial_n):
         def __call__(self, data):
             return Series.from_pylist([d.as_py() * self.n for d in data.to_arrow()])

-    if actor_pool_enabled:
-        RepeatN = RepeatN.with_concurrency(1)
+    if use_actor_pool:
+        RepeatN = RepeatN.with_concurrency(2)

-    with pytest.raises(ValueError, match="Cannot call StatefulUDF without initialization arguments."):
+    with pytest.raises(ValueError, match="Cannot call class UDF without initialization arguments."):
         RepeatN(col("a"))

     expr = RepeatN.with_init_args(initial_n=2)(col("a"))
@@ -123,7 +113,8 @@ def __call__(self, data):
     assert result.to_pydict() == {"a": ["foofoo", "barbar", "bazbaz"]}


-def test_class_udf_init_args_bad_args(actor_pool_enabled):
+@pytest.mark.parametrize("use_actor_pool", [False, True])
+def test_class_udf_init_args_bad_args(use_actor_pool):
     @udf(return_dtype=DataType.string())
     class RepeatN:
         def __init__(self, initial_n):
@@ -132,16 +123,15 @@ def __init__(self, initial_n):
         def __call__(self, data):
             return Series.from_pylist([d.as_py() * self.n for d in data.to_arrow()])

-    if actor_pool_enabled:
-        RepeatN = RepeatN.with_concurrency(1)
+    if use_actor_pool:
+        RepeatN = RepeatN.with_concurrency(2)

     with pytest.raises(TypeError, match="missing a required argument: 'initial_n'"):
         RepeatN.with_init_args(wrong=5)


 @pytest.mark.parametrize("concurrency", [1, 2, 4])
-@pytest.mark.parametrize("actor_pool_enabled", [True], indirect=True)
-def test_stateful_udf_concurrency(concurrency, actor_pool_enabled):
+def test_actor_pool_udf_concurrency(concurrency):
     df = daft.from_pydict({"a": ["foo", "bar", "baz"]})

     @udf(return_dtype=DataType.string(), batch_size=1)
@@ -233,11 +223,15 @@ def throw_value_err(x):

 @pytest.mark.parametrize("batch_size", [None, 1, 2, 3, 10])
-def test_no_args_udf_call(batch_size):
+@pytest.mark.parametrize("use_actor_pool", [False, True])
+def test_no_args_udf_call(batch_size, use_actor_pool):
     @udf(return_dtype=DataType.int64(), batch_size=batch_size)
     def udf_no_args():
         pass

+    if use_actor_pool:
+        udf_no_args = udf_no_args.with_concurrency(2)
+
     assert isinstance(udf_no_args(), Expression)

     with pytest.raises(TypeError):
@@ -247,18 +241,23 @@ def udf_no_args():
         udf_no_args(invalid="invalid")


-def test_full_udf_call():
+@pytest.mark.parametrize("use_actor_pool", [False, True])
+def test_full_udf_call(use_actor_pool):
     @udf(return_dtype=DataType.int64())
     def full_udf(e_arg, val, kwarg_val=None, kwarg_ex=None):
         pass

+    if use_actor_pool:
+        full_udf = full_udf.with_concurrency(2)
+
     assert isinstance(full_udf(col("x"), 1, kwarg_val=0, kwarg_ex=col("y")), Expression)

     with pytest.raises(TypeError):
         full_udf()


-def test_class_udf_initialization_error(actor_pool_enabled):
+@pytest.mark.parametrize("use_actor_pool", [False, True])
+def test_class_udf_initialization_error(use_actor_pool):
     df = daft.from_pydict({"a": ["foo", "bar", "baz"]})

     @udf(return_dtype=DataType.string())
@@ -269,11 +268,11 @@ def __init__(self):
         def __call__(self, data):
             return data

-    if actor_pool_enabled:
+    if use_actor_pool:
         IdentityWithInitError = IdentityWithInitError.with_concurrency(1)

     expr = IdentityWithInitError(col("a"))
-    if actor_pool_enabled:
+    if use_actor_pool:
         with pytest.raises(Exception):
             df.select(expr).collect()
     else:
@@ -281,11 +280,15 @@ def __call__(self, data):
         df.select(expr).collect()


-def test_udf_equality():
+@pytest.mark.parametrize("use_actor_pool", [False, True])
+def test_udf_equality(use_actor_pool):
     @udf(return_dtype=DataType.int64())
     def udf1(x):
         pass

+    if use_actor_pool:
+        udf1 = udf1.with_concurrency(2)
+
     assert expr_structurally_equal(udf1("x"), udf1("x"))
     assert not expr_structurally_equal(udf1("x"), udf1("y"))
diff --git a/tests/test_resource_requests.py b/tests/test_resource_requests.py
index 0170eac0d3..2f3d867f32 100644
--- a/tests/test_resource_requests.py
+++ b/tests/test_resource_requests.py
@@ -8,7 +8,6 @@
 import daft
 from daft import udf
-from daft.context import get_context, set_planning_config
 from daft.daft import SystemInfo
 from daft.expressions import col
 from daft.internal.gpu import cuda_visible_devices
@@ -39,38 +38,38 @@ def my_udf(c):

 def test_partial_resource_request_overrides():
     new_udf = my_udf.override_options(num_cpus=1.0)
-    assert new_udf.common_args.resource_request.num_cpus == 1.0
-    assert new_udf.common_args.resource_request.num_gpus is None
-    assert new_udf.common_args.resource_request.memory_bytes is None
+    assert new_udf.resource_request.num_cpus == 1.0
+    assert new_udf.resource_request.num_gpus is None
+    assert new_udf.resource_request.memory_bytes is None

     new_udf = new_udf.override_options(num_gpus=8.0)
-    assert new_udf.common_args.resource_request.num_cpus == 1.0
-    assert new_udf.common_args.resource_request.num_gpus == 8.0
-    assert new_udf.common_args.resource_request.memory_bytes is None
+    assert new_udf.resource_request.num_cpus == 1.0
+    assert new_udf.resource_request.num_gpus == 8.0
+    assert new_udf.resource_request.memory_bytes is None

     new_udf = new_udf.override_options(num_gpus=None)
-    assert new_udf.common_args.resource_request.num_cpus == 1.0
-    assert new_udf.common_args.resource_request.num_gpus is None
-    assert new_udf.common_args.resource_request.memory_bytes is None
+    assert new_udf.resource_request.num_cpus == 1.0
+    assert new_udf.resource_request.num_gpus is None
+    assert new_udf.resource_request.memory_bytes is None

     new_udf = new_udf.override_options(memory_bytes=100)
-    assert new_udf.common_args.resource_request.num_cpus == 1.0
-    assert new_udf.common_args.resource_request.num_gpus is None
-    assert new_udf.common_args.resource_request.memory_bytes == 100
+    assert new_udf.resource_request.num_cpus == 1.0
+    assert new_udf.resource_request.num_gpus is None
+    assert new_udf.resource_request.memory_bytes == 100


 def test_resource_request_pickle_roundtrip():
     new_udf = my_udf.override_options(num_cpus=1.0)
-    assert new_udf.common_args.resource_request.num_cpus == 1.0
-    assert new_udf.common_args.resource_request.num_gpus is None
-    assert new_udf.common_args.resource_request.memory_bytes is None
+    assert new_udf.resource_request.num_cpus == 1.0
+    assert new_udf.resource_request.num_gpus is None
+    assert new_udf.resource_request.memory_bytes is None
     assert new_udf == copy.deepcopy(new_udf)

     new_udf = new_udf.override_options(num_gpus=8.0)
-    assert new_udf.common_args.resource_request.num_cpus == 1.0
-    assert new_udf.common_args.resource_request.num_gpus == 8.0
-    assert new_udf.common_args.resource_request.memory_bytes is None
+    assert new_udf.resource_request.num_cpus == 1.0
+    assert new_udf.resource_request.num_gpus == 8.0
+    assert new_udf.resource_request.memory_bytes is None
     assert new_udf == copy.deepcopy(new_udf)
@@ -128,19 +127,6 @@ def test_requesting_too_much_memory():
 ###


-@pytest.fixture(scope="function", params=[True])
-def enable_actor_pool():
-    try:
-        original_config = get_context().daft_planning_config
-
-        set_planning_config(
-            config=get_context().daft_planning_config.with_config_values(enable_actor_pool_projections=True)
-        )
-        yield
-    finally:
-        set_planning_config(config=original_config)
-
-
 @udf(return_dtype=daft.DataType.int64())
 def assert_resources(c, num_cpus=None, num_gpus=None, memory=None):
     assigned_resources = ray.get_runtime_context().get_assigned_resources()
@@ -223,7 +209,7 @@ def test_with_column_folded_rayrunner():
     RAY_VERSION_LT_2, reason="The ray.get_runtime_context().get_assigned_resources() was only added in Ray >= 2.0"
 )
 @pytest.mark.skipif(get_tests_daft_runner_name() not in {"ray"}, reason="requires RayRunner to be in use")
-def test_with_column_rayrunner_class(enable_actor_pool):
+def test_with_column_rayrunner_class():
     assert_resources = AssertResourcesStateful.with_concurrency(1)

     df = daft.from_pydict(DATA).repartition(2)
@@ -241,7 +227,7 @@ def test_with_column_rayrunner_class(enable_actor_pool):
     RAY_VERSION_LT_2, reason="The ray.get_runtime_context().get_assigned_resources() was only added in Ray >= 2.0"
 )
 @pytest.mark.skipif(get_tests_daft_runner_name() not in {"ray"}, reason="requires RayRunner to be in use")
-def test_with_column_folded_rayrunner_class(enable_actor_pool):
+def test_with_column_folded_rayrunner_class():
     assert_resources = AssertResourcesStateful.with_concurrency(1)

     df = daft.from_pydict(DATA).repartition(2)
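
For reference, a minimal sketch of the user-facing API after this change, based only on the calls exercised by the updated tests (`with_concurrency` to turn a class UDF into an actor pool UDF, `with_init_args` to supply `__init__` arguments); no planning-config flag is needed anymore:

```python
import daft
from daft import col, udf
from daft.datatype import DataType
from daft.series import Series


@udf(return_dtype=DataType.string())
class RepeatN:
    def __init__(self, initial_n: int = 2):
        self.n = initial_n

    def __call__(self, data):
        # Repeat each input string n times, as in the RepeatN test UDF.
        return Series.from_pylist([d.as_py() * self.n for d in data.to_arrow()])


df = daft.from_pydict({"a": ["foo", "bar", "baz"]})

# with_concurrency marks the class UDF to run on a pool of 2 actors;
# with_init_args overrides the __init__ arguments before the call.
RepeatN = RepeatN.with_concurrency(2)
expr = RepeatN.with_init_args(initial_n=3)(col("a"))
df.select(expr).collect()
```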
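Likewise, a sketch of the flattened resource-request accessor these tests now rely on: `udf.resource_request` replaces the old `udf.common_args.resource_request` nesting, while `override_options` keeps its previous keyword arguments:

```python
import daft
from daft import udf


@udf(return_dtype=daft.DataType.int64())
def my_udf(c):
    return c


# Each override_options call returns a new UDF; the resource request
# is now read directly off the UDF object.
limited = my_udf.override_options(num_cpus=1.0).override_options(memory_bytes=100)
assert limited.resource_request.num_cpus == 1.0
assert limited.resource_request.memory_bytes == 100
assert limited.resource_request.num_gpus is None
```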