Skip to content

Commit

Permalink
[FEAT] Enable Actor Pool UDFs by default (#3488)
Browse files Browse the repository at this point in the history
Todo: 
- [x] fix tests
- [ ] add docs (future PR)
- [ ] add threaded concurrency (future PR)
  • Loading branch information
kevinzwang authored Dec 5, 2024
1 parent 8de0101 commit 56f5089
Show file tree
Hide file tree
Showing 36 changed files with 810 additions and 1,561 deletions.
26 changes: 7 additions & 19 deletions daft/daft/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ from daft.io.scan import ScanOperator
from daft.plan_scheduler.physical_plan_scheduler import PartitionT
from daft.runners.partitioning import PartitionCacheEntry
from daft.sql.sql_connection import SQLConnection
from daft.udf import InitArgsType, PartialStatefulUDF, PartialStatelessUDF
from daft.udf import BoundUDFArgs, InitArgsType, UninitializedUdf

if TYPE_CHECKING:
import pyarrow as pa
Expand Down Expand Up @@ -1123,29 +1123,20 @@ def interval_lit(
) -> PyExpr: ...
def decimal_lit(sign: bool, digits: tuple[int, ...], exp: int) -> PyExpr: ...
def series_lit(item: PySeries) -> PyExpr: ...
def stateless_udf(
def udf(
name: str,
partial_stateless_udf: PartialStatelessUDF,
inner: UninitializedUdf,
bound_args: BoundUDFArgs,
expressions: list[PyExpr],
return_dtype: PyDataType,
resource_request: ResourceRequest | None,
batch_size: int | None,
) -> PyExpr: ...
def stateful_udf(
name: str,
partial_stateful_udf: PartialStatefulUDF,
expressions: list[PyExpr],
return_dtype: PyDataType,
resource_request: ResourceRequest | None,
init_args: InitArgsType,
resource_request: ResourceRequest | None,
batch_size: int | None,
concurrency: int | None,
) -> PyExpr: ...
def check_column_name_validity(name: str, schema: PySchema): ...
def extract_partial_stateful_udf_py(
expression: PyExpr,
) -> dict[str, tuple[PartialStatefulUDF, InitArgsType]]: ...
def bind_stateful_udfs(expression: PyExpr, initialized_funcs: dict[str, Callable]) -> PyExpr: ...
def initialize_udfs(expression: PyExpr) -> PyExpr: ...
def get_udf_names(expression: PyExpr) -> list[str]: ...
def resolve_expr(expr: PyExpr, schema: PySchema) -> tuple[PyExpr, PyField]: ...
def hash(expr: PyExpr, seed: Any | None = None) -> PyExpr: ...
def cosine_distance(expr: PyExpr, other: PyExpr) -> PyExpr: ...
Expand Down Expand Up @@ -1885,12 +1876,9 @@ class PyDaftPlanningConfig:
def with_config_values(
self,
default_io_config: IOConfig | None = None,
enable_actor_pool_projections: bool | None = None,
) -> PyDaftPlanningConfig: ...
@property
def default_io_config(self) -> IOConfig: ...
@property
def enable_actor_pool_projections(self) -> bool: ...

def build_type() -> str: ...
def version() -> str: ...
Expand Down
53 changes: 53 additions & 0 deletions daft/execution/actor_pool_udf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from __future__ import annotations

import logging
import multiprocessing as mp
from typing import TYPE_CHECKING

from daft.expressions import Expression, ExpressionsProjection
from daft.table import MicroPartition

if TYPE_CHECKING:
from multiprocessing.connection import Connection

from daft.daft import PyExpr, PyMicroPartition

logger = logging.getLogger(__name__)


def actor_event_loop(uninitialized_projection: ExpressionsProjection, conn: Connection) -> None:
    """
    Serve evaluation requests inside an actor process until told to stop.

    The UDFs in ``uninitialized_projection`` are initialized once, up front.
    After that, every object received on ``conn`` is evaluated against the
    initialized projection and the result is sent back on the same connection.
    Receiving ``None`` ends the loop.
    """
    initialized_projection = ExpressionsProjection(
        [expr._initialize_udfs() for expr in uninitialized_projection]
    )

    # Priming read, then one read per iteration; ``None`` is the shutdown sentinel.
    partition: MicroPartition | None = conn.recv()
    while partition is not None:
        conn.send(partition.eval_expression_list(initialized_projection))
        partition = conn.recv()


class ActorHandle:
    """Handle class for initializing, interacting with, and tearing down a single local actor process.

    The handle keeps one end of a ``multiprocessing`` pipe (``handle_conn``) and
    gives the other end to a child process running ``actor_event_loop``.
    """

    def __init__(self, projection: list[PyExpr]) -> None:
        """Create the pipe and start the actor process.

        Args:
            projection: Expressions to evaluate in the actor. They are wrapped
                into an ``ExpressionsProjection`` and passed to the child
                process together with the child's end of the pipe.
        """
        self.handle_conn, actor_conn = mp.Pipe()

        expr_projection = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in projection])
        self.actor_process = mp.Process(target=actor_event_loop, args=(expr_projection, actor_conn))
        self.actor_process.start()

    def eval_input(self, input: PyMicroPartition) -> PyMicroPartition:
        """Send one partition to the actor and block until its result is received.

        Args:
            input: Partition to evaluate; wrapped in a ``MicroPartition`` before sending.

        Returns:
            The ``_micropartition`` of the object received back from the actor.
        """
        self.handle_conn.send(MicroPartition._from_pymicropartition(input))
        output: MicroPartition = self.handle_conn.recv()
        return output._micropartition

    def teardown(self) -> None:
        """Ask the actor to stop, then release the pipe and wait for the process.

        ``None`` is the sentinel that makes ``actor_event_loop`` exit. If sending
        it fails (for example because the actor process is already gone), the
        error still propagates, but the connection is closed and the process is
        joined first so neither is leaked.
        """
        try:
            self.handle_conn.send(None)
        finally:
            try:
                self.handle_conn.close()
            finally:
                self.actor_process.join()
4 changes: 2 additions & 2 deletions daft/execution/execution_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,7 @@ def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata])


@dataclass(frozen=True)
class StatefulUDFProject(SingleOutputInstruction):
class ActorPoolProject(SingleOutputInstruction):
projection: ExpressionsProjection

def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
Expand All @@ -564,7 +564,7 @@ def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata])
PartialPartitionMetadata(
num_rows=None, # UDFs can potentially change cardinality
size_bytes=None,
boundaries=None, # TODO: figure out if the stateful UDF projection changes boundaries
boundaries=None, # TODO: figure out if the actor pool UDF projection changes boundaries
)
]

Expand Down
12 changes: 5 additions & 7 deletions daft/execution/physical_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def actor_pool_context(
name: Name of the actor pool for debugging/observability
resource_request: Requested amount of resources for each actor
num_actors: Number of actors to spin up
projection: Projection to be run on the incoming data (contains Stateful UDFs as well as other stateless expressions such as aliases)
projection: Projection to be run on the incoming data (contains actor pool UDFs as well as other stateless expressions such as aliases)
"""
...

Expand All @@ -243,12 +243,10 @@ def actor_pool_project(
) -> InProgressPhysicalPlan[PartitionT]:
stage_id = next(stage_id_counter)

from daft.daft import extract_partial_stateful_udf_py
from daft.daft import get_udf_names

stateful_udf_names = "-".join(
name for expr in projection for name in extract_partial_stateful_udf_py(expr._expr).keys()
)
actor_pool_name = f"{stateful_udf_names}-stage={stage_id}"
udf_names = "-".join(name for expr in projection for name in get_udf_names(expr._expr))
actor_pool_name = f"{udf_names}-stage={stage_id}"

# Keep track of materializations of the children tasks
child_materializations: deque[SingleOutputPartitionTask[PartitionT]] = deque()
Expand Down Expand Up @@ -285,7 +283,7 @@ def actor_pool_project(
actor_pool_id=actor_pool_id,
)
.add_instruction(
instruction=execution_step.StatefulUDFProject(projection),
instruction=execution_step.ActorPoolProject(projection),
resource_request=task_resource_request,
)
.finalize_partition_task_single_output(
Expand Down
79 changes: 0 additions & 79 deletions daft/execution/stateful_actor.py

This file was deleted.

50 changes: 16 additions & 34 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import daft.daft as native
from daft import context
from daft.daft import CountMode, ImageFormat, ImageMode, ResourceRequest, bind_stateful_udfs
from daft.daft import CountMode, ImageFormat, ImageMode, ResourceRequest, initialize_udfs
from daft.daft import PyExpr as _PyExpr
from daft.daft import col as _col
from daft.daft import date_lit as _date_lit
Expand All @@ -28,13 +28,12 @@
from daft.daft import list_sort as _list_sort
from daft.daft import lit as _lit
from daft.daft import series_lit as _series_lit
from daft.daft import stateful_udf as _stateful_udf
from daft.daft import stateless_udf as _stateless_udf
from daft.daft import time_lit as _time_lit
from daft.daft import timestamp_lit as _timestamp_lit
from daft.daft import to_struct as _to_struct
from daft.daft import tokenize_decode as _tokenize_decode
from daft.daft import tokenize_encode as _tokenize_encode
from daft.daft import udf as _udf
from daft.daft import url_download as _url_download
from daft.daft import utf8_count_matches as _utf8_count_matches
from daft.datatype import DataType, TimeUnit
Expand All @@ -45,7 +44,7 @@

if TYPE_CHECKING:
from daft.io import IOConfig
from daft.udf import PartialStatefulUDF, PartialStatelessUDF
from daft.udf import BoundUDFArgs, InitArgsType, UninitializedUdf
# This allows Sphinx to correctly work against our "namespaced" accessor functions by overriding @property to
# return a class instance of the namespace instead of a property object.
elif os.getenv("DAFT_SPHINX_BUILD") == "1":
Expand Down Expand Up @@ -260,39 +259,26 @@ def _to_expression(obj: object) -> Expression:
return lit(obj)

@staticmethod
def stateless_udf(
def udf(
name: builtins.str,
partial: PartialStatelessUDF,
inner: UninitializedUdf,
bound_args: BoundUDFArgs,
expressions: builtins.list[Expression],
return_dtype: DataType,
init_args: InitArgsType,
resource_request: ResourceRequest | None,
batch_size: int | None,
) -> Expression:
return Expression._from_pyexpr(
_stateless_udf(
name, partial, [e._expr for e in expressions], return_dtype._dtype, resource_request, batch_size
)
)

@staticmethod
def stateful_udf(
name: builtins.str,
partial: PartialStatefulUDF,
expressions: builtins.list[Expression],
return_dtype: DataType,
resource_request: ResourceRequest | None,
init_args: tuple[tuple[Any, ...], dict[builtins.str, Any]] | None,
batch_size: int | None,
concurrency: int | None,
) -> Expression:
return Expression._from_pyexpr(
_stateful_udf(
_udf(
name,
partial,
inner,
bound_args,
[e._expr for e in expressions],
return_dtype._dtype,
resource_request,
init_args,
resource_request,
batch_size,
concurrency,
)
Expand Down Expand Up @@ -1018,7 +1004,7 @@ def apply(self, func: Callable, return_dtype: DataType) -> Expression:
Returns:
Expression: New expression after having run the function on the expression
"""
from daft.udf import CommonUDFArgs, StatelessUDF
from daft.udf import UDF

def batch_func(self_series):
return [func(x) for x in self_series.to_pylist()]
Expand All @@ -1028,14 +1014,10 @@ def batch_func(self_series):
name = name + "."
name = name + getattr(func, "__qualname__") # type: ignore[call-overload]

return StatelessUDF(
return UDF(
inner=batch_func,
name=name,
func=batch_func,
return_dtype=return_dtype,
common_args=CommonUDFArgs(
resource_request=None,
batch_size=None,
),
)(self)

def is_null(self) -> Expression:
Expand Down Expand Up @@ -1263,8 +1245,8 @@ def __reduce__(self) -> tuple:
def _input_mapping(self) -> builtins.str | None:
return self._expr._input_mapping()

def _bind_stateful_udfs(self, initialized_funcs: dict[builtins.str, Callable]) -> Expression:
return Expression._from_pyexpr(bind_stateful_udfs(self._expr, initialized_funcs))
def _initialize_udfs(self) -> Expression:
return Expression._from_pyexpr(initialize_udfs(self._expr))


SomeExpressionNamespace = TypeVar("SomeExpressionNamespace", bound="ExpressionNamespace")
Expand Down
Loading

0 comments on commit 56f5089

Please sign in to comment.