Skip to content

Commit

Permalink
[FEAT] Support __init__ arguments for StatefulUDFs (#2634)
Browse files Browse the repository at this point in the history
This PR supports `__init__` arguments to be supplied to StatefulUDFs.

Users can now define arguments for `__init__` in their StatefulUDFs, and
additionally tweak the arguments at runtime by calling
`MyUDF.with_init_args(...)`

---------

Co-authored-by: Jay Chia <[email protected]@users.noreply.github.com>
  • Loading branch information
jaychia and Jay Chia authored Aug 12, 2024
1 parent 9ce458b commit d809072
Show file tree
Hide file tree
Showing 8 changed files with 169 additions and 14 deletions.
1 change: 1 addition & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1195,6 +1195,7 @@ def stateful_udf(
expressions: list[PyExpr],
return_dtype: PyDataType,
resource_request: ResourceRequest | None,
init_args: tuple[tuple[Any, ...], dict[str, Any]] | None,
) -> PyExpr: ...
def resolve_expr(expr: PyExpr, schema: PySchema) -> tuple[PyExpr, PyField]: ...
def hash(expr: PyExpr, seed: Any | None = None) -> PyExpr: ...
Expand Down
5 changes: 4 additions & 1 deletion daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,12 @@ def stateful_udf(
expressions: builtins.list[Expression],
return_dtype: DataType,
resource_request: ResourceRequest | None,
init_args: tuple[tuple[Any, ...], dict[builtins.str, Any]] | None,
) -> Expression:
return Expression._from_pyexpr(
_stateful_udf(name, partial, [e._expr for e in expressions], return_dtype._dtype, resource_request)
_stateful_udf(
name, partial, [e._expr for e in expressions], return_dtype._dtype, resource_request, init_args
)
)

def __bool__(self) -> bool:
Expand Down
64 changes: 62 additions & 2 deletions daft/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ class StatefulUDF(UDF):
name: str
cls: type
return_dtype: DataType
init_args: tuple[tuple[Any, ...], dict[str, Any]] | None = None

def __post_init__(self):
"""Analogous to the @functools.wraps(self.cls) pattern
Expand All @@ -265,6 +266,18 @@ def __post_init__(self):
functools.update_wrapper(self, self.cls)

def __call__(self, *args, **kwargs) -> Expression:
# Validate that initialization arguments are provided if the __init__ signature indicates that there are
# parameters without defaults
init_sig = inspect.signature(self.cls.__init__) # type: ignore
if (
any(param.default is param.empty for param in init_sig.parameters.values() if param.name != "self")
and self.init_args is None
):
raise ValueError(
"Cannot call StatefulUDF without initialization arguments. Please either specify default arguments in your __init__ or provide "
"initialization arguments using `.with_init_args(...)`."
)

bound_args = BoundUDFArgs(self.bind_func(*args, **kwargs))
expressions = list(bound_args.expressions().values())
return Expression.stateful_udf(
Expand All @@ -273,7 +286,54 @@ def __call__(self, *args, **kwargs) -> Expression:
expressions=expressions,
return_dtype=self.return_dtype,
resource_request=self.resource_request,
init_args=self.init_args,
)

def with_init_args(self, *args, **kwargs) -> StatefulUDF:
"""Replace initialization arguments for the Stateful UDF when calling __init__ at runtime
on each instance of the UDF.
Example:
>>> import daft
>>>
>>> @daft.udf(return_dtype=daft.DataType.string())
... class MyInitializedClass:
... def __init__(self, text=" world"):
... self.text = text
...
... def __call__(self, data):
... return [x + self.text for x in data.to_pylist()]
>>>
>>> # Create a customized version of MyInitializedClass by overriding the init args
>>> MyInitializedClass_CustomInitArgs = MyInitializedClass.with_init_args(text=" my old friend")
>>>
>>> df = daft.from_pydict({"foo": ["hello", "hello", "hello"]})
>>> df = df.with_column("bar_world", MyInitializedClass(df["foo"]))
>>> df = df.with_column("bar_custom", MyInitializedClass_CustomInitArgs(df["foo"]))
>>> df.show()
╭───────┬─────────────┬─────────────────────╮
│ foo ┆ bar_world ┆ bar_custom │
│ --- ┆ --- ┆ --- │
│ Utf8 ┆ Utf8 ┆ Utf8 │
╞═══════╪═════════════╪═════════════════════╡
│ hello ┆ hello world ┆ hello my old friend │
├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ hello ┆ hello world ┆ hello my old friend │
├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ hello ┆ hello world ┆ hello my old friend │
╰───────┴─────────────┴─────────────────────╯
<BLANKLINE>
(Showing first 3 of 3 rows)
"""
init_sig = inspect.signature(self.cls.__init__) # type: ignore
init_sig.bind(
# Placeholder for `self`
None,
*args,
**kwargs,
)
return dataclasses.replace(self, init_args=(args, kwargs))

def bind_func(self, *args, **kwargs) -> inspect.BoundArguments:
sig = inspect.signature(self.cls.__call__)
Expand All @@ -296,7 +356,7 @@ def udf(
num_cpus: float | None = None,
num_gpus: float | None = None,
memory_bytes: int | None = None,
) -> Callable[[UserProvidedPythonFunction | type], UDF]:
) -> Callable[[UserProvidedPythonFunction | type], StatelessUDF | StatefulUDF]:
"""Decorator to convert a Python function into a UDF
UDFs allow users to run arbitrary Python code on the outputs of Expressions.
Expand Down Expand Up @@ -408,7 +468,7 @@ def udf(
Callable[[UserProvidedPythonFunction], UDF]: UDF decorator - converts a user-provided Python function as a UDF that can be called on Expressions
"""

def _udf(f: UserProvidedPythonFunction | type) -> UDF:
def _udf(f: UserProvidedPythonFunction | type) -> StatelessUDF | StatefulUDF:
# Grab a name for the UDF. It **should** be unique.
name = getattr(f, "__module__", "") # type: ignore[call-overload]
if name:
Expand Down
14 changes: 9 additions & 5 deletions src/daft-dsl/src/functions/python/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#[cfg(feature = "python")]
mod partial_udf;
mod pyobj_serde;
mod udf;

use std::sync::Arc;
Expand Down Expand Up @@ -35,7 +35,7 @@ impl PythonUDF {
pub struct StatelessPythonUDF {
pub name: Arc<String>,
#[cfg(feature = "python")]
partial_func: partial_udf::PyPartialUDF,
partial_func: pyobj_serde::PyObjectWrapper,
num_expressions: usize,
pub return_dtype: DataType,
pub resource_request: Option<ResourceRequest>,
Expand All @@ -45,10 +45,12 @@ pub struct StatelessPythonUDF {
pub struct StatefulPythonUDF {
pub name: Arc<String>,
#[cfg(feature = "python")]
pub stateful_partial_func: partial_udf::PyPartialUDF,
pub stateful_partial_func: pyobj_serde::PyObjectWrapper,
pub num_expressions: usize,
pub return_dtype: DataType,
pub resource_request: Option<ResourceRequest>,
#[cfg(feature = "python")]
pub init_args: Option<pyobj_serde::PyObjectWrapper>,
}

#[cfg(feature = "python")]
Expand All @@ -62,7 +64,7 @@ pub fn stateless_udf(
Ok(Expr::Function {
func: super::FunctionExpr::Python(PythonUDF::Stateless(StatelessPythonUDF {
name: name.to_string().into(),
partial_func: partial_udf::PyPartialUDF(py_partial_stateless_udf),
partial_func: pyobj_serde::PyObjectWrapper(py_partial_stateless_udf),
num_expressions: expressions.len(),
return_dtype,
resource_request,
Expand Down Expand Up @@ -96,14 +98,16 @@ pub fn stateful_udf(
expressions: &[ExprRef],
return_dtype: DataType,
resource_request: Option<ResourceRequest>,
init_args: Option<pyo3::PyObject>,
) -> DaftResult<Expr> {
Ok(Expr::Function {
func: super::FunctionExpr::Python(PythonUDF::Stateful(StatefulPythonUDF {
name: name.to_string().into(),
stateful_partial_func: partial_udf::PyPartialUDF(py_stateful_partial_func),
stateful_partial_func: pyobj_serde::PyObjectWrapper(py_stateful_partial_func),
num_expressions: expressions.len(),
return_dtype,
resource_request,
init_args: init_args.map(pyobj_serde::PyObjectWrapper),
})),
inputs: expressions.into(),
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,23 @@ use serde::{Deserialize, Serialize};

// This is a Rust wrapper on top of a Python PartialStatelessUDF or PartialStatefulUDF to make it serde-able and hashable
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PyPartialUDF(
pub struct PyObjectWrapper(
#[serde(
serialize_with = "serialize_py_object",
deserialize_with = "deserialize_py_object"
)]
pub PyObject,
);

impl PartialEq for PyPartialUDF {
impl PartialEq for PyObjectWrapper {
fn eq(&self, other: &Self) -> bool {
Python::with_gil(|py| self.0.as_ref(py).eq(other.0.as_ref(py)).unwrap())
}
}

impl Eq for PyPartialUDF {}
impl Eq for PyObjectWrapper {}

impl Hash for PyPartialUDF {
impl Hash for PyObjectWrapper {
fn hash<H: Hasher>(&self, state: &mut H) {
let py_obj_hash = Python::with_gil(|py| self.0.as_ref(py).hash());
match py_obj_hash {
Expand Down
27 changes: 25 additions & 2 deletions src/daft-dsl/src/functions/python/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,10 @@ impl FunctionEvaluator for StatefulPythonUDF {

#[cfg(feature = "python")]
fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult<Series> {
use pyo3::Python;
use pyo3::{
types::{PyDict, PyTuple},
Python,
};

if inputs.len() != self.num_expressions {
return Err(DaftError::SchemaMismatch(format!(
Expand All @@ -175,7 +178,27 @@ impl FunctionEvaluator for StatefulPythonUDF {
// HACK: This is the naive initialization of the class. It is performed once-per-evaluate which is not ideal.
// Ideally we need to allow evaluate to somehow take in the **initialized** Python class that is provided by the Actor.
// Either that, or the code-path to evaluate a StatefulUDF should bypass `evaluate` entirely and do its own thing.
let func = func.call0(py)?;
let func = match &self.init_args {
None => func.call0(py)?,
Some(init_args) => {
let init_args = init_args
.0
.as_ref(py)
.downcast::<PyTuple>()
.expect("init_args should be a Python tuple");
let (args, kwargs) = (
init_args
.get_item(0)?
.downcast::<PyTuple>()
.expect("init_args[0] should be a tuple of *args"),
init_args
.get_item(1)?
.downcast::<PyDict>()
.expect("init_args[1] should be a dict of **kwargs"),
);
func.call(py, args, Some(kwargs))?
}
};

run_udf(py, inputs, func, bound_args, &self.return_dtype)
})
Expand Down
3 changes: 3 additions & 0 deletions src/daft-dsl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,20 +186,23 @@ pub fn stateful_udf(
expressions: Vec<PyExpr>,
return_dtype: PyDataType,
resource_request: Option<ResourceRequest>,
init_args: Option<&PyAny>,
) -> PyResult<PyExpr> {
use crate::functions::python::stateful_udf;

// Convert &PyAny values to a GIL-independent reference to Python objects (PyObject) so that we can store them in our Rust Expr enums
// See: https://pyo3.rs/v0.18.2/types#pyt-and-pyobject
let partial_stateful_udf = partial_stateful_udf.to_object(py);
let expressions_map: Vec<ExprRef> = expressions.into_iter().map(|pyexpr| pyexpr.expr).collect();
let init_args = init_args.map(|args| args.to_object(py));
Ok(PyExpr {
expr: stateful_udf(
name,
partial_stateful_udf,
&expressions_map,
return_dtype.dtype,
resource_request,
init_args,
)?
.into(),
})
Expand Down
61 changes: 61 additions & 0 deletions tests/expressions/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,67 @@ def __call__(self, data):
assert result.to_pydict() == {"a": ["foofoo", "barbar", "bazbaz"]}


def test_class_udf_init_args():
table = MicroPartition.from_pydict({"a": ["foo", "bar", "baz"]})

@udf(return_dtype=DataType.string())
class RepeatN:
def __init__(self, initial_n: int = 2):
self.n = initial_n

def __call__(self, data):
return Series.from_pylist([d.as_py() * self.n for d in data.to_arrow()])

expr = RepeatN(col("a"))
field = expr._to_field(table.schema())
assert field.name == "a"
assert field.dtype == DataType.string()
result = table.eval_expression_list([expr])
assert result.to_pydict() == {"a": ["foofoo", "barbar", "bazbaz"]}

expr = RepeatN.with_init_args(initial_n=3)(col("a"))
field = expr._to_field(table.schema())
assert field.name == "a"
assert field.dtype == DataType.string()
result = table.eval_expression_list([expr])
assert result.to_pydict() == {"a": ["foofoofoo", "barbarbar", "bazbazbaz"]}


def test_class_udf_init_args_no_default():
table = MicroPartition.from_pydict({"a": ["foo", "bar", "baz"]})

@udf(return_dtype=DataType.string())
class RepeatN:
def __init__(self, initial_n):
self.n = initial_n

def __call__(self, data):
return Series.from_pylist([d.as_py() * self.n for d in data.to_arrow()])

with pytest.raises(ValueError, match="Cannot call StatefulUDF without initialization arguments."):
RepeatN(col("a"))

expr = RepeatN.with_init_args(initial_n=2)(col("a"))
field = expr._to_field(table.schema())
assert field.name == "a"
assert field.dtype == DataType.string()
result = table.eval_expression_list([expr])
assert result.to_pydict() == {"a": ["foofoo", "barbar", "bazbaz"]}


def test_class_udf_init_args_bad_args():
@udf(return_dtype=DataType.string())
class RepeatN:
def __init__(self, initial_n):
self.n = initial_n

def __call__(self, data):
return Series.from_pylist([d.as_py() * self.n for d in data.to_arrow()])

with pytest.raises(TypeError, match="missing a required argument: 'initial_n'"):
RepeatN.with_init_args(wrong=5)


def test_udf_kwargs():
table = MicroPartition.from_pydict({"a": ["foo", "bar", "baz"]})

Expand Down

0 comments on commit d809072

Please sign in to comment.