From e01594f4cae0c39ee93c6c209d7dcbfabbfa986f Mon Sep 17 00:00:00 2001 From: clarkzinzow Date: Wed, 20 Sep 2023 15:05:42 -0700 Subject: [PATCH] Centralize pyo3 pickling around __reduce__ + bincode macro. --- daft/daft.pyi | 12 +- daft/datatype.py | 8 +- daft/expressions/expressions.py | 8 +- daft/logical/schema.py | 8 +- src/daft-core/src/count_mode.rs | 19 +-- src/daft-core/src/datatypes/image_format.rs | 2 +- src/daft-core/src/datatypes/image_mode.rs | 2 +- src/daft-core/src/python/datatype.rs | 126 ++++++++---------- src/daft-core/src/python/field.rs | 42 ++---- src/daft-core/src/python/schema.rs | 41 ++---- src/daft-core/src/utils/mod.rs | 18 ++- src/daft-dsl/src/python.rs | 38 ++---- src/daft-plan/src/join.rs | 19 +-- src/daft-plan/src/partitioning.rs | 21 +-- src/daft-plan/src/physical_plan.rs | 30 +---- src/daft-plan/src/resource_request.rs | 5 +- src/daft-plan/src/source_info/file_format.rs | 38 +----- src/daft-plan/src/source_info/file_info.rs | 24 ++-- .../src/source_info/storage_config.rs | 22 +-- 19 files changed, 130 insertions(+), 353 deletions(-) diff --git a/daft/daft.pyi b/daft/daft.pyi index e96b2eaa6f..0dd4b4bb8d 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -471,16 +471,14 @@ class PyDataType: def is_equal(self, other: Any) -> builtins.bool: ... @staticmethod def from_json(serialized: str) -> PyDataType: ... - def __setstate__(self, state: Any) -> None: ... - def __getstate__(self) -> Any: ... + def __reduce__(self) -> tuple: ... def __hash__(self) -> int: ... class PyField: def name(self) -> str: ... def dtype(self) -> PyDataType: ... def eq(self, other: PyField) -> bool: ... - def __setstate__(self, state: Any) -> None: ... - def __getstate__(self) -> Any: ... + def __reduce__(self) -> tuple: ... class PySchema: def __getitem__(self, name: str) -> PyField: ... @@ -491,8 +489,7 @@ class PySchema: def from_field_name_and_types(names_and_types: list[tuple[str, PyDataType]]) -> PySchema: ... @staticmethod def from_fields(fields: list[PyField]) -> PySchema: ... - def __setstate__(self, state: Any) -> None: ... - def __getstate__(self) -> Any: ... + def __reduce__(self) -> tuple: ... def __repr__(self) -> str: ... def _repr_html_(self) -> str: ... @@ -534,8 +531,7 @@ class PyExpr: def to_field(self, schema: PySchema) -> PyField: ... def __repr__(self) -> str: ... def __hash__(self) -> int: ... - def __setstate__(self, state: Any) -> None: ... - def __getstate__(self) -> Any: ... + def __reduce__(self) -> tuple: ... def is_nan(self) -> PyExpr: ... def dt_date(self) -> PyExpr: ... def dt_day(self) -> PyExpr: ... diff --git a/daft/datatype.py b/daft/datatype.py index 327d39278a..60a7da7060 100644 --- a/daft/datatype.py +++ b/daft/datatype.py @@ -458,12 +458,8 @@ def __repr__(self) -> str: def __eq__(self, other: object) -> builtins.bool: return isinstance(other, DataType) and self._dtype.is_equal(other._dtype) - def __getstate__(self) -> bytes: - return self._dtype.__getstate__() - - def __setstate__(self, state: bytes) -> None: - self._dtype = PyDataType.__new__(PyDataType) - self._dtype.__setstate__(state) + def __reduce__(self) -> tuple: + return DataType._from_pydatatype, (self._dtype,) def __hash__(self) -> int: return self._dtype.__hash__() diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index ef667b7db7..2eaf5eb290 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -382,12 +382,8 @@ def _to_field(self, schema: Schema) -> Field: def __hash__(self) -> int: return self._expr.__hash__() - def __getstate__(self) -> bytes: - return self._expr.__getstate__() - - def __setstate__(self, state: bytes) -> None: - self._expr = _PyExpr.__new__(_PyExpr) - self._expr.__setstate__(state) + def __reduce__(self) -> tuple: + return Expression._from_pyexpr, (self._expr,) ### # Helper methods required by optimizer: diff --git a/daft/logical/schema.py b/daft/logical/schema.py index 74c4778f20..2c11a9d29f 100644 --- a/daft/logical/schema.py +++ b/daft/logical/schema.py @@ -134,12 +134,8 @@ def union(self, other: Schema) -> Schema: return Schema._from_pyschema(self._schema.union(other._schema)) - def __getstate__(self) -> bytes: - return self._schema.__getstate__() - - def __setstate__(self, state: bytes) -> None: - self._schema = _PySchema.__new__(_PySchema) - self._schema.__setstate__(state) + def __reduce__(self) -> tuple: + return Schema._from_pyschema, (self._schema,) @classmethod def from_parquet( diff --git a/src/daft-core/src/count_mode.rs b/src/daft-core/src/count_mode.rs index b6c5383fa6..055963d962 100644 --- a/src/daft-core/src/count_mode.rs +++ b/src/daft-core/src/count_mode.rs @@ -1,8 +1,6 @@ #[cfg(feature = "python")] use pyo3::{ - exceptions::PyValueError, - prelude::*, - types::{PyBytes, PyTuple}, + exceptions::PyValueError, prelude::*, types::PyBytes, PyObject, PyTypeInfo, ToPyObject, }; use serde::{Deserialize, Serialize}; use std::fmt::{Display, Formatter, Result}; @@ -20,7 +18,7 @@ use common_error::{DaftError, DaftResult}; /// | Null - Count only null values. #[allow(clippy::upper_case_acronyms)] #[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)] -#[cfg_attr(feature = "python", pyclass)] +#[cfg_attr(feature = "python", pyclass(module = "daft.daft"))] pub enum CountMode { All = 1, Valid = 2, @@ -30,19 +28,6 @@ pub enum CountMode { #[cfg(feature = "python")] #[pymethods] impl CountMode { - #[new] - #[pyo3(signature = (*args))] - pub fn new(args: &PyTuple) -> PyResult { - match args.len() { - // Create dummy variant, to be overridden by __setstate__. - 0 => Ok(Self::All), - _ => Err(PyValueError::new_err(format!( - "expected no arguments to make new JoinType, got : {}", - args.len() - ))), - } - } - /// Create a CountMode from its string representation. /// /// Args: diff --git a/src/daft-core/src/datatypes/image_format.rs b/src/daft-core/src/datatypes/image_format.rs index a5c5d8959f..098eed99af 100644 --- a/src/daft-core/src/datatypes/image_format.rs +++ b/src/daft-core/src/datatypes/image_format.rs @@ -11,7 +11,7 @@ use common_error::{DaftError, DaftResult}; /// Supported image formats for Daft's I/O layer. #[allow(clippy::upper_case_acronyms)] #[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)] -#[cfg_attr(feature = "python", pyclass)] +#[cfg_attr(feature = "python", pyclass(module = "daft.daft"))] pub enum ImageFormat { PNG, JPEG, diff --git a/src/daft-core/src/datatypes/image_mode.rs b/src/daft-core/src/datatypes/image_mode.rs index 8738084528..a4a34ea90e 100644 --- a/src/daft-core/src/datatypes/image_mode.rs +++ b/src/daft-core/src/datatypes/image_mode.rs @@ -24,7 +24,7 @@ use common_error::{DaftError, DaftResult}; /// | RGBA32F - 32-bit floating RGB + alpha #[allow(clippy::upper_case_acronyms)] #[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, Hash, FromPrimitive)] -#[cfg_attr(feature = "python", pyclass)] +#[cfg_attr(feature = "python", pyclass(module = "daft.daft"))] pub enum ImageMode { L = 1, LA = 2, diff --git a/src/daft-core/src/python/datatype.rs b/src/daft-core/src/python/datatype.rs index a76946e177..09973227d9 100644 --- a/src/daft-core/src/python/datatype.rs +++ b/src/daft-core/src/python/datatype.rs @@ -1,13 +1,15 @@ use crate::{ datatypes::{DataType, Field, ImageMode, TimeUnit}, - ffi, + ffi, impl_bincode_py_state_serialization, }; use pyo3::{ class::basic::CompareOp, exceptions::PyValueError, prelude::*, - types::{PyBytes, PyDict, PyString, PyTuple}, + types::{PyBytes, PyDict, PyString}, + PyTypeInfo, }; +use serde::{Deserialize, Serialize}; #[pyclass] #[derive(Clone)] @@ -66,26 +68,14 @@ impl PyTimeUnit { } } -#[pyclass] -#[derive(Clone)] +#[pyclass(module = "daft.daft")] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct PyDataType { pub dtype: DataType, } #[pymethods] impl PyDataType { - #[new] - #[pyo3(signature = (*args))] - pub fn new(args: &PyTuple) -> PyResult { - match args.len() { - 0 => Ok(DataType::new_null().into()), - _ => Err(PyValueError::new_err(format!( - "expected no arguments to make new PyDataType, got : {}", - args.len() - ))), - } - } - pub fn __repr__(&self) -> PyResult { Ok(format!("{}", self.dtype)) } @@ -283,49 +273,51 @@ impl PyDataType { Ok(DataType::Python.into()) } - pub fn to_arrow(&self, cast_tensor_type_for_ray: Option) -> PyResult { - Python::with_gil(|py| { - let pyarrow = py.import(pyo3::intern!(py, "pyarrow"))?; - let cast_tensor_to_ray_type = cast_tensor_type_for_ray.unwrap_or(false); - match (&self.dtype, cast_tensor_to_ray_type) { - (DataType::FixedShapeTensor(dtype, shape), false) => Ok( - if py - .import(pyo3::intern!(py, "daft.utils"))? - .getattr(pyo3::intern!(py, "pyarrow_supports_fixed_shape_tensor"))? - .call0()? - .extract()? - { - pyarrow - .getattr(pyo3::intern!(py, "fixed_shape_tensor"))? - .call1(( - Self { - dtype: *dtype.clone(), - } - .to_arrow(None)?, - pyo3::types::PyTuple::new(py, shape.clone()), - ))? - .to_object(py) - } else { - // Fall back to default Daft super extension representation if installed pyarrow doesn't have the - // canonical tensor extension type. - ffi::to_py_schema(&self.dtype.to_arrow()?, py, pyarrow)? - }, - ), - (DataType::FixedShapeTensor(dtype, shape), true) => Ok(py - .import(pyo3::intern!(py, "ray.data.extensions"))? - .getattr(pyo3::intern!(py, "ArrowTensorType"))? - .call1(( - pyo3::types::PyTuple::new(py, shape.clone()), - Self { - dtype: *dtype.clone(), - } - .to_arrow(None)?, - ))? - .to_object(py)), - (_, _) => ffi::to_py_schema(&self.dtype.to_arrow()?, py, pyarrow)? - .getattr(py, pyo3::intern!(py, "type")), - } - }) + pub fn to_arrow( + &self, + py: Python, + cast_tensor_type_for_ray: Option, + ) -> PyResult { + let pyarrow = py.import(pyo3::intern!(py, "pyarrow"))?; + let cast_tensor_to_ray_type = cast_tensor_type_for_ray.unwrap_or(false); + match (&self.dtype, cast_tensor_to_ray_type) { + (DataType::FixedShapeTensor(dtype, shape), false) => Ok( + if py + .import(pyo3::intern!(py, "daft.utils"))? + .getattr(pyo3::intern!(py, "pyarrow_supports_fixed_shape_tensor"))? + .call0()? + .extract()? + { + pyarrow + .getattr(pyo3::intern!(py, "fixed_shape_tensor"))? + .call1(( + Self { + dtype: *dtype.clone(), + } + .to_arrow(py, None)?, + pyo3::types::PyTuple::new(py, shape.clone()), + ))? + .to_object(py) + } else { + // Fall back to default Daft super extension representation if installed pyarrow doesn't have the + // canonical tensor extension type. + ffi::to_py_schema(&self.dtype.to_arrow()?, py, pyarrow)? + }, + ), + (DataType::FixedShapeTensor(dtype, shape), true) => Ok(py + .import(pyo3::intern!(py, "ray.data.extensions"))? + .getattr(pyo3::intern!(py, "ArrowTensorType"))? + .call1(( + pyo3::types::PyTuple::new(py, shape.clone()), + Self { + dtype: *dtype.clone(), + } + .to_arrow(py, None)?, + ))? + .to_object(py)), + (_, _) => ffi::to_py_schema(&self.dtype.to_arrow()?, py, pyarrow)? + .getattr(py, pyo3::intern!(py, "type")), + } } pub fn is_image(&self) -> PyResult { @@ -366,20 +358,6 @@ impl PyDataType { Ok(DataType::from_json(serialized)?.into()) } - pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> { - match state.extract::<&PyBytes>(py) { - Ok(s) => { - self.dtype = bincode::deserialize(s.as_bytes()).unwrap(); - Ok(()) - } - Err(e) => Err(e), - } - } - - pub fn __getstate__(&self, py: Python) -> PyResult { - Ok(PyBytes::new(py, &bincode::serialize(&self.dtype).unwrap()).to_object(py)) - } - pub fn __hash__(&self) -> u64 { use std::collections::hash_map::DefaultHasher; use std::hash::Hash; @@ -390,6 +368,8 @@ impl PyDataType { } } +impl_bincode_py_state_serialization!(PyDataType); + impl From for PyDataType { fn from(value: DataType) -> Self { PyDataType { dtype: value } diff --git a/src/daft-core/src/python/field.rs b/src/daft-core/src/python/field.rs index dff13a5a86..daae3e63b2 100644 --- a/src/daft-core/src/python/field.rs +++ b/src/daft-core/src/python/field.rs @@ -1,32 +1,18 @@ -use pyo3::{ - exceptions::PyValueError, - prelude::*, - types::{PyBytes, PyTuple}, -}; +use pyo3::{prelude::*, types::PyBytes, PyTypeInfo}; +use serde::{Deserialize, Serialize}; use super::datatype::PyDataType; -use crate::datatypes::{self, DataType, Field}; +use crate::datatypes; +use crate::impl_bincode_py_state_serialization; -#[pyclass] -#[derive(Clone)] +#[pyclass(module = "daft.daft")] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct PyField { pub field: datatypes::Field, } #[pymethods] impl PyField { - #[new] - #[pyo3(signature = (*args))] - pub fn new(args: &PyTuple) -> PyResult { - match args.len() { - 0 => Ok(Field::new("null", DataType::new_null()).into()), - _ => Err(PyValueError::new_err(format!( - "expected no arguments to make new PyDataType, got : {}", - args.len() - ))), - } - } - pub fn name(&self) -> PyResult { Ok(self.field.name.clone()) } @@ -38,22 +24,10 @@ impl PyField { pub fn eq(&self, other: &PyField) -> PyResult { Ok(self.field.eq(&other.field)) } - - pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> { - match state.extract::<&PyBytes>(py) { - Ok(s) => { - self.field = bincode::deserialize(s.as_bytes()).unwrap(); - Ok(()) - } - Err(e) => Err(e), - } - } - - pub fn __getstate__(&self, py: Python) -> PyResult { - Ok(PyBytes::new(py, &bincode::serialize(&self.field).unwrap()).to_object(py)) - } } +impl_bincode_py_state_serialization!(PyField); + impl From for PyField { fn from(field: datatypes::Field) -> Self { PyField { field } diff --git a/src/daft-core/src/python/schema.rs b/src/daft-core/src/python/schema.rs index 5be0331218..27e4be6335 100644 --- a/src/daft-core/src/python/schema.rs +++ b/src/daft-core/src/python/schema.rs @@ -1,38 +1,25 @@ use std::sync::Arc; -use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use pyo3::types::PyBytes; -use pyo3::types::PyTuple; +use pyo3::PyTypeInfo; + +use serde::{Deserialize, Serialize}; use super::datatype::PyDataType; use super::field::PyField; use crate::datatypes; +use crate::impl_bincode_py_state_serialization; use crate::schema; -use crate::schema::Schema; -#[pyclass] -#[derive(Clone)] +#[pyclass(module = "daft.daft")] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct PySchema { pub schema: schema::SchemaRef, } #[pymethods] impl PySchema { - #[new] - #[pyo3(signature = (*args))] - pub fn new(args: &PyTuple) -> PyResult { - match args.len() { - 0 => Ok(Self { - schema: Schema::empty().into(), - }), - _ => Err(PyValueError::new_err(format!( - "expected no arguments to make new PyDataType, got : {}", - args.len() - ))), - } - } - pub fn __getitem__(&self, name: &str) -> PyResult { Ok(self.schema.get_field(name)?.clone().into()) } @@ -71,20 +58,6 @@ impl PySchema { }) } - pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> { - match state.extract::<&PyBytes>(py) { - Ok(s) => { - self.schema = bincode::deserialize(s.as_bytes()).unwrap(); - Ok(()) - } - Err(e) => Err(e), - } - } - - pub fn __getstate__(&self, py: Python) -> PyResult { - Ok(PyBytes::new(py, &bincode::serialize(&self.schema).unwrap()).to_object(py)) - } - pub fn __repr__(&self) -> PyResult { Ok(format!("{}", self.schema)) } @@ -94,6 +67,8 @@ impl PySchema { } } +impl_bincode_py_state_serialization!(PySchema); + impl From for PySchema { fn from(schema: schema::SchemaRef) -> Self { PySchema { schema } diff --git a/src/daft-core/src/utils/mod.rs b/src/daft-core/src/utils/mod.rs index c3e44e03ae..d4c644994a 100644 --- a/src/daft-core/src/utils/mod.rs +++ b/src/daft-core/src/utils/mod.rs @@ -20,13 +20,21 @@ macro_rules! impl_bincode_py_state_serialization { #[cfg(feature = "python")] #[pymethods] impl $ty { - pub fn __setstate__(&mut self, state: &PyBytes) -> PyResult<()> { - *self = bincode::deserialize(state.as_bytes()).unwrap(); - Ok(()) + pub fn __reduce__(&self, py: Python) -> PyResult<(PyObject, PyObject)> { + Ok(( + Self::type_object(py) + .getattr("_from_serialized")? + .to_object(py), + (PyBytes::new(py, &bincode::serialize(&self).unwrap()).to_object(py),) + .to_object(py), + )) } - pub fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<&'py PyBytes> { - Ok(PyBytes::new(py, &bincode::serialize(&self).unwrap())) + #[staticmethod] + pub fn _from_serialized(py: Python, serialized: PyObject) -> PyResult { + serialized + .extract::<&PyBytes>(py) + .map(|s| bincode::deserialize(s.as_bytes()).unwrap()) } } }; diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index b89e71f570..c64afc4271 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -2,10 +2,13 @@ use std::collections::hash_map::DefaultHasher; use std::collections::HashSet; use std::hash::{Hash, Hasher}; +use serde::{Deserialize, Serialize}; + use crate::{functions, optimization, Expr}; use daft_core::{ count_mode::CountMode, datatypes::ImageFormat, + impl_bincode_py_state_serialization, python::{datatype::PyDataType, field::PyField, schema::PySchema}, }; @@ -14,7 +17,8 @@ use pyo3::{ exceptions::PyValueError, prelude::*, pyclass::CompareOp, - types::{PyBool, PyBytes, PyFloat, PyInt, PyString, PyTuple}, + types::{PyBool, PyBytes, PyFloat, PyInt, PyString}, + PyTypeInfo, }; #[pyfunction] @@ -83,8 +87,8 @@ pub fn udf( }) } -#[pyclass] -#[derive(Clone)] +#[pyclass(module = "daft.daft")] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct PyExpr { pub expr: crate::Expr, } @@ -96,18 +100,6 @@ pub fn eq(expr1: &PyExpr, expr2: &PyExpr) -> PyResult { #[pymethods] impl PyExpr { - #[new] - #[pyo3(signature = (*args))] - pub fn new(args: &PyTuple) -> PyResult { - match args.len() { - 0 => Ok(crate::null_lit().into()), - _ => Err(PyValueError::new_err(format!( - "expected no arguments to make new PyExpr, got : {}", - args.len() - ))), - } - } - pub fn _input_mapping(&self) -> PyResult> { Ok(self.expr.input_mapping()) } @@ -254,20 +246,6 @@ impl PyExpr { hasher.finish() } - pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> { - match state.extract::<&PyBytes>(py) { - Ok(s) => { - self.expr = bincode::deserialize(s.as_bytes()).unwrap(); - Ok(()) - } - Err(e) => Err(e), - } - } - - pub fn __getstate__(&self, py: Python) -> PyResult { - Ok(PyBytes::new(py, &bincode::serialize(&self.expr).unwrap()).to_object(py)) - } - pub fn is_nan(&self) -> PyResult { use functions::float::is_nan; Ok(is_nan(&self.expr).into()) @@ -382,6 +360,8 @@ impl PyExpr { } } +impl_bincode_py_state_serialization!(PyExpr); + impl From for PyExpr { fn from(value: crate::Expr) -> Self { PyExpr { expr: value } diff --git a/src/daft-plan/src/join.rs b/src/daft-plan/src/join.rs index c1290ed126..725d8a000f 100644 --- a/src/daft-plan/src/join.rs +++ b/src/daft-plan/src/join.rs @@ -7,10 +7,8 @@ use common_error::{DaftError, DaftResult}; use daft_core::impl_bincode_py_state_serialization; #[cfg(feature = "python")] use pyo3::{ - exceptions::PyValueError, - pyclass, pymethods, - types::{PyBytes, PyTuple}, - PyResult, Python, + exceptions::PyValueError, pyclass, pymethods, types::PyBytes, PyObject, PyResult, PyTypeInfo, + Python, ToPyObject, }; use serde::{Deserialize, Serialize}; @@ -27,19 +25,6 @@ pub enum JoinType { #[cfg(feature = "python")] #[pymethods] impl JoinType { - #[new] - #[pyo3(signature = (*args))] - pub fn new(args: &PyTuple) -> PyResult { - match args.len() { - // Create dummy variant, to be overridden by __setstate__. - 0 => Ok(Self::Inner), - _ => Err(PyValueError::new_err(format!( - "expected no arguments to make new JoinType, got : {}", - args.len() - ))), - } - } - /// Create a JoinType from its string representation. /// /// Args: diff --git a/src/daft-plan/src/partitioning.rs b/src/daft-plan/src/partitioning.rs index 675734e529..a965c5d2d1 100644 --- a/src/daft-plan/src/partitioning.rs +++ b/src/daft-plan/src/partitioning.rs @@ -7,8 +7,8 @@ use serde::{Deserialize, Serialize}; use { daft_dsl::python::PyExpr, pyo3::{ - exceptions::PyValueError, pyclass, pyclass::CompareOp, pymethods, types::PyBytes, - types::PyTuple, PyResult, Python, + pyclass, pyclass::CompareOp, pymethods, types::PyBytes, PyObject, PyResult, PyTypeInfo, + Python, ToPyObject, }, }; @@ -22,23 +22,6 @@ pub enum PartitionScheme { Unknown, } -#[cfg(feature = "python")] -#[pymethods] -impl PartitionScheme { - #[new] - #[pyo3(signature = (*args))] - pub fn new(args: &PyTuple) -> PyResult { - match args.len() { - // Create dummy variant, to be overridden by __setstate__. - 0 => Ok(Self::Unknown), - _ => Err(PyValueError::new_err(format!( - "expected no arguments to make new PartitionScheme, got : {}", - args.len() - ))), - } - } -} - impl_bincode_py_state_serialization!(PartitionScheme); /// Partition specification: scheme, number of partitions, partition column. diff --git a/src/daft-plan/src/physical_plan.rs b/src/daft-plan/src/physical_plan.rs index ea281da1f0..9c87facaa9 100644 --- a/src/daft-plan/src/physical_plan.rs +++ b/src/daft-plan/src/physical_plan.rs @@ -12,10 +12,8 @@ use { daft_dsl::python::PyExpr, daft_dsl::Expr, pyo3::{ - exceptions::PyValueError, - pyclass, pymethods, - types::{PyBytes, PyTuple}, - PyObject, PyRef, PyRefMut, PyResult, Python, + pyclass, pymethods, types::PyBytes, PyObject, PyRef, PyRefMut, PyResult, PyTypeInfo, + Python, ToPyObject, }, std::collections::HashMap, }; @@ -56,7 +54,7 @@ pub enum PhysicalPlan { } /// A work scheduler for physical plans. -#[cfg_attr(feature = "python", pyclass)] +#[cfg_attr(feature = "python", pyclass(module = "daft.daft"))] #[derive(Debug, Serialize, Deserialize)] pub struct PhysicalPlanScheduler { plan: Arc, @@ -65,28 +63,6 @@ pub struct PhysicalPlanScheduler { #[cfg(feature = "python")] #[pymethods] impl PhysicalPlanScheduler { - #[new] - #[pyo3(signature = (*args))] - pub fn new(args: &PyTuple) -> PyResult { - match args.len() { - // Create dummy inner PhysicalPlan, to be overridden by __setstate__. - 0 => Ok(Arc::new(PhysicalPlan::InMemoryScan(InMemoryScan::new( - Default::default(), - InMemoryInfo::new( - daft_core::schema::Schema::new(vec![])?.into(), - "".to_string(), - args.py().None(), - ), - Default::default(), - ))) - .into()), - _ => Err(PyValueError::new_err(format!( - "expected no arguments to make new PhysicalPlanScheduler, got : {}", - args.len() - ))), - } - } - /// Converts the contained physical plan into an iterator of executable partition tasks. pub fn to_partition_tasks(&self, psets: HashMap>) -> PyResult { Python::with_gil(|py| self.plan.to_partition_tasks(py, &psets)) diff --git a/src/daft-plan/src/resource_request.rs b/src/daft-plan/src/resource_request.rs index 95b8ccd392..d758815e48 100644 --- a/src/daft-plan/src/resource_request.rs +++ b/src/daft-plan/src/resource_request.rs @@ -2,7 +2,10 @@ use daft_core::{impl_bincode_py_state_serialization, utils::hashable_float_wrapp use std::hash::{Hash, Hasher}; #[cfg(feature = "python")] use { - pyo3::{pyclass, pyclass::CompareOp, pymethods, types::PyBytes, PyResult, Python}, + pyo3::{ + pyclass, pyclass::CompareOp, pymethods, types::PyBytes, PyObject, PyResult, PyTypeInfo, + Python, ToPyObject, + }, std::{cmp::max, ops::Add}, }; diff --git a/src/daft-plan/src/source_info/file_format.rs b/src/daft-plan/src/source_info/file_format.rs index e7cae3ff89..fdcb483e61 100644 --- a/src/daft-plan/src/source_info/file_format.rs +++ b/src/daft-plan/src/source_info/file_format.rs @@ -4,12 +4,8 @@ use std::sync::Arc; #[cfg(feature = "python")] use pyo3::{ - exceptions::PyValueError, - pyclass, - pyclass::CompareOp, - pymethods, - types::{PyBytes, PyTuple}, - IntoPy, PyObject, PyResult, Python, + pyclass, pyclass::CompareOp, pymethods, types::PyBytes, IntoPy, PyObject, PyResult, PyTypeInfo, + Python, ToPyObject, }; /// Format of a file, e.g. Parquet, CSV, JSON. @@ -21,23 +17,6 @@ pub enum FileFormat { Json, } -#[cfg(feature = "python")] -#[pymethods] -impl FileFormat { - #[new] - #[pyo3(signature = (*args))] - pub fn new(args: &PyTuple) -> PyResult { - match args.len() { - // Create dummy variant, to be overridden by __setstate__. - 0 => Ok(Self::Json), - _ => Err(PyValueError::new_err(format!( - "expected no arguments to make new FileFormat, got : {}", - args.len() - ))), - } - } -} - impl_bincode_py_state_serialization!(FileFormat); impl From<&FileFormatConfig> for FileFormat { @@ -144,19 +123,6 @@ pub struct PyFileFormatConfig(Arc); #[cfg(feature = "python")] #[pymethods] impl PyFileFormatConfig { - #[new] - #[pyo3(signature = (*args))] - pub fn new(args: &PyTuple) -> PyResult { - match args.len() { - // Create dummy inner FileFormatConfig, to be overridden by __setstate__. - 0 => Ok(Arc::new(FileFormatConfig::Json(JsonSourceConfig::new())).into()), - _ => Err(PyValueError::new_err(format!( - "expected no arguments to make new PyFileFormatConfig, got : {}", - args.len() - ))), - } - } - /// Create a Parquet file format config. #[staticmethod] fn from_parquet_config(config: ParquetSourceConfig) -> Self { diff --git a/src/daft-plan/src/source_info/file_info.rs b/src/daft-plan/src/source_info/file_info.rs index f71e9f41a3..dd82d8b8e7 100644 --- a/src/daft-plan/src/source_info/file_info.rs +++ b/src/daft-plan/src/source_info/file_info.rs @@ -7,10 +7,8 @@ use serde::{Deserialize, Serialize}; use { daft_table::python::PyTable, pyo3::{ - exceptions::{PyKeyError, PyValueError}, - pyclass, pymethods, - types::{PyBytes, PyTuple}, - PyResult, Python, + exceptions::PyKeyError, pyclass, pymethods, types::PyBytes, PyObject, PyResult, PyTypeInfo, + Python, ToPyObject, }, }; @@ -55,16 +53,8 @@ pub struct FileInfos { #[pymethods] impl FileInfos { #[new] - #[pyo3(signature = (*args))] - pub fn new(args: &PyTuple) -> PyResult { - match args.len() { - // Create an empty FileInfos, to be overridden by __setstate__ and/or extended with self.extend(). - 0 => Ok(Self::new_internal(vec![], vec![], vec![])), - _ => Err(PyValueError::new_err(format!( - "expected no arguments to make new FileInfos, got : {}", - args.len() - ))), - } + pub fn new() -> Self { + Default::default() } #[staticmethod] @@ -189,3 +179,9 @@ impl FileInfos { ) } } + +impl Default for FileInfos { + fn default() -> Self { + Self::new_internal(vec![], vec![], vec![]) + } +} diff --git a/src/daft-plan/src/source_info/storage_config.rs b/src/daft-plan/src/source_info/storage_config.rs index 30ae0d4efd..34e1e43a10 100644 --- a/src/daft-plan/src/source_info/storage_config.rs +++ b/src/daft-plan/src/source_info/storage_config.rs @@ -9,10 +9,8 @@ use { super::py_object_serde::{deserialize_py_object_optional, serialize_py_object_optional}, common_io_config::python, pyo3::{ - exceptions::PyValueError, - pyclass, pymethods, - types::{PyBytes, PyTuple}, - IntoPy, PyObject, PyResult, Python, + pyclass, pymethods, types::PyBytes, IntoPy, PyObject, PyResult, PyTypeInfo, Python, + ToPyObject, }, std::hash::{Hash, Hasher}, }; @@ -119,22 +117,6 @@ pub struct PyStorageConfig(Arc); #[cfg(feature = "python")] #[pymethods] impl PyStorageConfig { - #[new] - #[pyo3(signature = (*args))] - pub fn new(args: &PyTuple) -> PyResult { - match args.len() { - // Create dummy inner StorageConfig, to be overridden by __setstate__. - 0 => Ok(Arc::new(StorageConfig::Native( - NativeStorageConfig::new_internal(None).into(), - )) - .into()), - _ => Err(PyValueError::new_err(format!( - "expected no arguments to make new PyStorageConfig, got : {}", - args.len() - ))), - } - } - /// Create from a native storage config. #[staticmethod] fn native(config: NativeStorageConfig) -> Self {