From d8090722d7fc9bd1f27273797d7c201b39cc3ed7 Mon Sep 17 00:00:00 2001 From: Jay Chia <17691182+jaychia@users.noreply.github.com> Date: Mon, 12 Aug 2024 16:55:49 -0700 Subject: [PATCH] [FEAT] Support __init__ arguments for StatefulUDFs (#2634) This PR supports `__init__` arguments to be supplied to StatefulUDFs. Users can now define arguments for `__init__` in their StatefulUDFs, and additionally tweak the arguments at runtime by calling `MyUDF.with_init_args(...)` --------- Co-authored-by: Jay Chia --- daft/daft.pyi | 1 + daft/expressions/expressions.py | 5 +- daft/udf.py | 64 ++++++++++++++++++- src/daft-dsl/src/functions/python/mod.rs | 14 ++-- .../python/{partial_udf.rs => pyobj_serde.rs} | 8 +-- src/daft-dsl/src/functions/python/udf.rs | 27 +++++++- src/daft-dsl/src/python.rs | 3 + tests/expressions/test_udf.py | 61 ++++++++++++++++++ 8 files changed, 169 insertions(+), 14 deletions(-) rename src/daft-dsl/src/functions/python/{partial_udf.rs => pyobj_serde.rs} (89%) diff --git a/daft/daft.pyi b/daft/daft.pyi index 4a03335464..7db51df434 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -1195,6 +1195,7 @@ def stateful_udf( expressions: list[PyExpr], return_dtype: PyDataType, resource_request: ResourceRequest | None, + init_args: tuple[tuple[Any, ...], dict[str, Any]] | None, ) -> PyExpr: ... def resolve_expr(expr: PyExpr, schema: PySchema) -> tuple[PyExpr, PyField]: ... def hash(expr: PyExpr, seed: Any | None = None) -> PyExpr: ... diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index f223f2b485..6ac9e342d0 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -249,9 +249,12 @@ def stateful_udf( expressions: builtins.list[Expression], return_dtype: DataType, resource_request: ResourceRequest | None, + init_args: tuple[tuple[Any, ...], dict[builtins.str, Any]] | None, ) -> Expression: return Expression._from_pyexpr( - _stateful_udf(name, partial, [e._expr for e in expressions], return_dtype._dtype, resource_request) + _stateful_udf( + name, partial, [e._expr for e in expressions], return_dtype._dtype, resource_request, init_args + ) ) def __bool__(self) -> bool: diff --git a/daft/udf.py b/daft/udf.py index adecfcfdb2..e1f39b59a8 100644 --- a/daft/udf.py +++ b/daft/udf.py @@ -255,6 +255,7 @@ class StatefulUDF(UDF): name: str cls: type return_dtype: DataType + init_args: tuple[tuple[Any, ...], dict[str, Any]] | None = None def __post_init__(self): """Analogous to the @functools.wraps(self.cls) pattern @@ -265,6 +266,18 @@ def __post_init__(self): functools.update_wrapper(self, self.cls) def __call__(self, *args, **kwargs) -> Expression: + # Validate that initialization arguments are provided if the __init__ signature indicates that there are + # parameters without defaults + init_sig = inspect.signature(self.cls.__init__) # type: ignore + if ( + any(param.default is param.empty for param in init_sig.parameters.values() if param.name != "self") + and self.init_args is None + ): + raise ValueError( + "Cannot call StatefulUDF without initialization arguments. Please either specify default arguments in your __init__ or provide " + "initialization arguments using `.with_init_args(...)`." + ) + bound_args = BoundUDFArgs(self.bind_func(*args, **kwargs)) expressions = list(bound_args.expressions().values()) return Expression.stateful_udf( @@ -273,7 +286,54 @@ def __call__(self, *args, **kwargs) -> Expression: expressions=expressions, return_dtype=self.return_dtype, resource_request=self.resource_request, + init_args=self.init_args, + ) + + def with_init_args(self, *args, **kwargs) -> StatefulUDF: + """Replace initialization arguments for the Stateful UDF when calling __init__ at runtime + on each instance of the UDF. + + Example: + + >>> import daft + >>> + >>> @daft.udf(return_dtype=daft.DataType.string()) + ... class MyInitializedClass: + ... def __init__(self, text=" world"): + ... self.text = text + ... + ... def __call__(self, data): + ... return [x + self.text for x in data.to_pylist()] + >>> + >>> # Create a customized version of MyInitializedClass by overriding the init args + >>> MyInitializedClass_CustomInitArgs = MyInitializedClass.with_init_args(text=" my old friend") + >>> + >>> df = daft.from_pydict({"foo": ["hello", "hello", "hello"]}) + >>> df = df.with_column("bar_world", MyInitializedClass(df["foo"])) + >>> df = df.with_column("bar_custom", MyInitializedClass_CustomInitArgs(df["foo"])) + >>> df.show() + ╭───────┬─────────────┬─────────────────────╮ + │ foo ┆ bar_world ┆ bar_custom │ + │ --- ┆ --- ┆ --- │ + │ Utf8 ┆ Utf8 ┆ Utf8 │ + ╞═══════╪═════════════╪═════════════════════╡ + │ hello ┆ hello world ┆ hello my old friend │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ + │ hello ┆ hello world ┆ hello my old friend │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ + │ hello ┆ hello world ┆ hello my old friend │ + ╰───────┴─────────────┴─────────────────────╯ + + (Showing first 3 of 3 rows) + """ + init_sig = inspect.signature(self.cls.__init__) # type: ignore + init_sig.bind( + # Placeholder for `self` + None, + *args, + **kwargs, ) + return dataclasses.replace(self, init_args=(args, kwargs)) def bind_func(self, *args, **kwargs) -> inspect.BoundArguments: sig = inspect.signature(self.cls.__call__) @@ -296,7 +356,7 @@ def udf( num_cpus: float | None = None, num_gpus: float | None = None, memory_bytes: int | None = None, -) -> Callable[[UserProvidedPythonFunction | type], UDF]: +) -> Callable[[UserProvidedPythonFunction | type], StatelessUDF | StatefulUDF]: """Decorator to convert a Python function into a UDF UDFs allow users to run arbitrary Python code on the outputs of Expressions. @@ -408,7 +468,7 @@ def udf( Callable[[UserProvidedPythonFunction], UDF]: UDF decorator - converts a user-provided Python function as a UDF that can be called on Expressions """ - def _udf(f: UserProvidedPythonFunction | type) -> UDF: + def _udf(f: UserProvidedPythonFunction | type) -> StatelessUDF | StatefulUDF: # Grab a name for the UDF. It **should** be unique. name = getattr(f, "__module__", "") # type: ignore[call-overload] if name: diff --git a/src/daft-dsl/src/functions/python/mod.rs b/src/daft-dsl/src/functions/python/mod.rs index d62468f408..1ace104d1a 100644 --- a/src/daft-dsl/src/functions/python/mod.rs +++ b/src/daft-dsl/src/functions/python/mod.rs @@ -1,5 +1,5 @@ #[cfg(feature = "python")] -mod partial_udf; +mod pyobj_serde; mod udf; use std::sync::Arc; @@ -35,7 +35,7 @@ impl PythonUDF { pub struct StatelessPythonUDF { pub name: Arc, #[cfg(feature = "python")] - partial_func: partial_udf::PyPartialUDF, + partial_func: pyobj_serde::PyObjectWrapper, num_expressions: usize, pub return_dtype: DataType, pub resource_request: Option, @@ -45,10 +45,12 @@ pub struct StatelessPythonUDF { pub struct StatefulPythonUDF { pub name: Arc, #[cfg(feature = "python")] - pub stateful_partial_func: partial_udf::PyPartialUDF, + pub stateful_partial_func: pyobj_serde::PyObjectWrapper, pub num_expressions: usize, pub return_dtype: DataType, pub resource_request: Option, + #[cfg(feature = "python")] + pub init_args: Option, } #[cfg(feature = "python")] @@ -62,7 +64,7 @@ pub fn stateless_udf( Ok(Expr::Function { func: super::FunctionExpr::Python(PythonUDF::Stateless(StatelessPythonUDF { name: name.to_string().into(), - partial_func: partial_udf::PyPartialUDF(py_partial_stateless_udf), + partial_func: pyobj_serde::PyObjectWrapper(py_partial_stateless_udf), num_expressions: expressions.len(), return_dtype, resource_request, @@ -96,14 +98,16 @@ pub fn stateful_udf( expressions: &[ExprRef], return_dtype: DataType, resource_request: Option, + init_args: Option, ) -> DaftResult { Ok(Expr::Function { func: super::FunctionExpr::Python(PythonUDF::Stateful(StatefulPythonUDF { name: name.to_string().into(), - stateful_partial_func: partial_udf::PyPartialUDF(py_stateful_partial_func), + stateful_partial_func: pyobj_serde::PyObjectWrapper(py_stateful_partial_func), num_expressions: expressions.len(), return_dtype, resource_request, + init_args: init_args.map(pyobj_serde::PyObjectWrapper), })), inputs: expressions.into(), }) diff --git a/src/daft-dsl/src/functions/python/partial_udf.rs b/src/daft-dsl/src/functions/python/pyobj_serde.rs similarity index 89% rename from src/daft-dsl/src/functions/python/partial_udf.rs rename to src/daft-dsl/src/functions/python/pyobj_serde.rs index c902a5b381..fb03d7a0ce 100644 --- a/src/daft-dsl/src/functions/python/partial_udf.rs +++ b/src/daft-dsl/src/functions/python/pyobj_serde.rs @@ -6,7 +6,7 @@ use serde::{Deserialize, Serialize}; // This is a Rust wrapper on top of a Python PartialStatelessUDF or PartialStatefulUDF to make it serde-able and hashable #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct PyPartialUDF( +pub struct PyObjectWrapper( #[serde( serialize_with = "serialize_py_object", deserialize_with = "deserialize_py_object" @@ -14,15 +14,15 @@ pub struct PyPartialUDF( pub PyObject, ); -impl PartialEq for PyPartialUDF { +impl PartialEq for PyObjectWrapper { fn eq(&self, other: &Self) -> bool { Python::with_gil(|py| self.0.as_ref(py).eq(other.0.as_ref(py)).unwrap()) } } -impl Eq for PyPartialUDF {} +impl Eq for PyObjectWrapper {} -impl Hash for PyPartialUDF { +impl Hash for PyObjectWrapper { fn hash(&self, state: &mut H) { let py_obj_hash = Python::with_gil(|py| self.0.as_ref(py).hash()); match py_obj_hash { diff --git a/src/daft-dsl/src/functions/python/udf.rs b/src/daft-dsl/src/functions/python/udf.rs index 15d29f4c07..0b9f947e8c 100644 --- a/src/daft-dsl/src/functions/python/udf.rs +++ b/src/daft-dsl/src/functions/python/udf.rs @@ -151,7 +151,10 @@ impl FunctionEvaluator for StatefulPythonUDF { #[cfg(feature = "python")] fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - use pyo3::Python; + use pyo3::{ + types::{PyDict, PyTuple}, + Python, + }; if inputs.len() != self.num_expressions { return Err(DaftError::SchemaMismatch(format!( @@ -175,7 +178,27 @@ impl FunctionEvaluator for StatefulPythonUDF { // HACK: This is the naive initialization of the class. It is performed once-per-evaluate which is not ideal. // Ideally we need to allow evaluate to somehow take in the **initialized** Python class that is provided by the Actor. // Either that, or the code-path to evaluate a StatefulUDF should bypass `evaluate` entirely and do its own thing. - let func = func.call0(py)?; + let func = match &self.init_args { + None => func.call0(py)?, + Some(init_args) => { + let init_args = init_args + .0 + .as_ref(py) + .downcast::() + .expect("init_args should be a Python tuple"); + let (args, kwargs) = ( + init_args + .get_item(0)? + .downcast::() + .expect("init_args[0] should be a tuple of *args"), + init_args + .get_item(1)? + .downcast::() + .expect("init_args[1] should be a dict of **kwargs"), + ); + func.call(py, args, Some(kwargs))? + } + }; run_udf(py, inputs, func, bound_args, &self.return_dtype) }) diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index a4d1e91a90..2fafc4194b 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -186,6 +186,7 @@ pub fn stateful_udf( expressions: Vec, return_dtype: PyDataType, resource_request: Option, + init_args: Option<&PyAny>, ) -> PyResult { use crate::functions::python::stateful_udf; @@ -193,6 +194,7 @@ pub fn stateful_udf( // See: https://pyo3.rs/v0.18.2/types#pyt-and-pyobject let partial_stateful_udf = partial_stateful_udf.to_object(py); let expressions_map: Vec = expressions.into_iter().map(|pyexpr| pyexpr.expr).collect(); + let init_args = init_args.map(|args| args.to_object(py)); Ok(PyExpr { expr: stateful_udf( name, @@ -200,6 +202,7 @@ pub fn stateful_udf( &expressions_map, return_dtype.dtype, resource_request, + init_args, )? .into(), }) diff --git a/tests/expressions/test_udf.py b/tests/expressions/test_udf.py index 3451100584..2bd1f12d9b 100644 --- a/tests/expressions/test_udf.py +++ b/tests/expressions/test_udf.py @@ -49,6 +49,67 @@ def __call__(self, data): assert result.to_pydict() == {"a": ["foofoo", "barbar", "bazbaz"]} +def test_class_udf_init_args(): + table = MicroPartition.from_pydict({"a": ["foo", "bar", "baz"]}) + + @udf(return_dtype=DataType.string()) + class RepeatN: + def __init__(self, initial_n: int = 2): + self.n = initial_n + + def __call__(self, data): + return Series.from_pylist([d.as_py() * self.n for d in data.to_arrow()]) + + expr = RepeatN(col("a")) + field = expr._to_field(table.schema()) + assert field.name == "a" + assert field.dtype == DataType.string() + result = table.eval_expression_list([expr]) + assert result.to_pydict() == {"a": ["foofoo", "barbar", "bazbaz"]} + + expr = RepeatN.with_init_args(initial_n=3)(col("a")) + field = expr._to_field(table.schema()) + assert field.name == "a" + assert field.dtype == DataType.string() + result = table.eval_expression_list([expr]) + assert result.to_pydict() == {"a": ["foofoofoo", "barbarbar", "bazbazbaz"]} + + +def test_class_udf_init_args_no_default(): + table = MicroPartition.from_pydict({"a": ["foo", "bar", "baz"]}) + + @udf(return_dtype=DataType.string()) + class RepeatN: + def __init__(self, initial_n): + self.n = initial_n + + def __call__(self, data): + return Series.from_pylist([d.as_py() * self.n for d in data.to_arrow()]) + + with pytest.raises(ValueError, match="Cannot call StatefulUDF without initialization arguments."): + RepeatN(col("a")) + + expr = RepeatN.with_init_args(initial_n=2)(col("a")) + field = expr._to_field(table.schema()) + assert field.name == "a" + assert field.dtype == DataType.string() + result = table.eval_expression_list([expr]) + assert result.to_pydict() == {"a": ["foofoo", "barbar", "bazbaz"]} + + +def test_class_udf_init_args_bad_args(): + @udf(return_dtype=DataType.string()) + class RepeatN: + def __init__(self, initial_n): + self.n = initial_n + + def __call__(self, data): + return Series.from_pylist([d.as_py() * self.n for d in data.to_arrow()]) + + with pytest.raises(TypeError, match="missing a required argument: 'initial_n'"): + RepeatN.with_init_args(wrong=5) + + def test_udf_kwargs(): table = MicroPartition.from_pydict({"a": ["foo", "bar", "baz"]})