Skip to content

Commit

Permalink
Centralize pyo3 pickling around __reduce__ + bincode macro.
Browse files — browse the repository at this point in the history
  • Loading branch information
clarkzinzow committed Sep 20, 2023
1 parent f6370e5 commit e01594f
Show file tree
Hide file tree
Showing 19 changed files with 130 additions and 353 deletions.
12 changes: 4 additions & 8 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -471,16 +471,14 @@ class PyDataType:
def is_equal(self, other: Any) -> builtins.bool: ...
@staticmethod
def from_json(serialized: str) -> PyDataType: ...
def __setstate__(self, state: Any) -> None: ...
def __getstate__(self) -> Any: ...
def __reduce__(self) -> tuple: ...
def __hash__(self) -> int: ...

class PyField:
def name(self) -> str: ...
def dtype(self) -> PyDataType: ...
def eq(self, other: PyField) -> bool: ...
def __setstate__(self, state: Any) -> None: ...
def __getstate__(self) -> Any: ...
def __reduce__(self) -> tuple: ...

class PySchema:
def __getitem__(self, name: str) -> PyField: ...
Expand All @@ -491,8 +489,7 @@ class PySchema:
def from_field_name_and_types(names_and_types: list[tuple[str, PyDataType]]) -> PySchema: ...
@staticmethod
def from_fields(fields: list[PyField]) -> PySchema: ...
def __setstate__(self, state: Any) -> None: ...
def __getstate__(self) -> Any: ...
def __reduce__(self) -> tuple: ...
def __repr__(self) -> str: ...
def _repr_html_(self) -> str: ...

Expand Down Expand Up @@ -534,8 +531,7 @@ class PyExpr:
def to_field(self, schema: PySchema) -> PyField: ...
def __repr__(self) -> str: ...
def __hash__(self) -> int: ...
def __setstate__(self, state: Any) -> None: ...
def __getstate__(self) -> Any: ...
def __reduce__(self) -> tuple: ...
def is_nan(self) -> PyExpr: ...
def dt_date(self) -> PyExpr: ...
def dt_day(self) -> PyExpr: ...
Expand Down
8 changes: 2 additions & 6 deletions daft/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,12 +458,8 @@ def __repr__(self) -> str:
def __eq__(self, other: object) -> builtins.bool:
return isinstance(other, DataType) and self._dtype.is_equal(other._dtype)

def __getstate__(self) -> bytes:
return self._dtype.__getstate__()

def __setstate__(self, state: bytes) -> None:
self._dtype = PyDataType.__new__(PyDataType)
self._dtype.__setstate__(state)
def __reduce__(self) -> tuple:
return DataType._from_pydatatype, (self._dtype,)

def __hash__(self) -> int:
return self._dtype.__hash__()
Expand Down
8 changes: 2 additions & 6 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,12 +382,8 @@ def _to_field(self, schema: Schema) -> Field:
def __hash__(self) -> int:
return self._expr.__hash__()

def __getstate__(self) -> bytes:
return self._expr.__getstate__()

def __setstate__(self, state: bytes) -> None:
self._expr = _PyExpr.__new__(_PyExpr)
self._expr.__setstate__(state)
def __reduce__(self) -> tuple:
return Expression._from_pyexpr, (self._expr,)

###
# Helper methods required by optimizer:
Expand Down
8 changes: 2 additions & 6 deletions daft/logical/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,8 @@ def union(self, other: Schema) -> Schema:

return Schema._from_pyschema(self._schema.union(other._schema))

def __getstate__(self) -> bytes:
return self._schema.__getstate__()

def __setstate__(self, state: bytes) -> None:
self._schema = _PySchema.__new__(_PySchema)
self._schema.__setstate__(state)
def __reduce__(self) -> tuple:
return Schema._from_pyschema, (self._schema,)

@classmethod
def from_parquet(
Expand Down
19 changes: 2 additions & 17 deletions src/daft-core/src/count_mode.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
#[cfg(feature = "python")]
use pyo3::{
exceptions::PyValueError,
prelude::*,
types::{PyBytes, PyTuple},
exceptions::PyValueError, prelude::*, types::PyBytes, PyObject, PyTypeInfo, ToPyObject,
};
use serde::{Deserialize, Serialize};
use std::fmt::{Display, Formatter, Result};
Expand All @@ -20,7 +18,7 @@ use common_error::{DaftError, DaftResult};
/// | Null - Count only null values.
#[allow(clippy::upper_case_acronyms)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)]
#[cfg_attr(feature = "python", pyclass)]
#[cfg_attr(feature = "python", pyclass(module = "daft.daft"))]
pub enum CountMode {
All = 1,
Valid = 2,
Expand All @@ -30,19 +28,6 @@ pub enum CountMode {
#[cfg(feature = "python")]
#[pymethods]
impl CountMode {
#[new]
#[pyo3(signature = (*args))]
pub fn new(args: &PyTuple) -> PyResult<Self> {
match args.len() {
// Create dummy variant, to be overridden by __setstate__.
0 => Ok(Self::All),
_ => Err(PyValueError::new_err(format!(
"expected no arguments to make new JoinType, got : {}",
args.len()
))),
}
}

/// Create a CountMode from its string representation.
///
/// Args:
Expand Down
2 changes: 1 addition & 1 deletion src/daft-core/src/datatypes/image_format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use common_error::{DaftError, DaftResult};
/// Supported image formats for Daft's I/O layer.
#[allow(clippy::upper_case_acronyms)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)]
#[cfg_attr(feature = "python", pyclass)]
#[cfg_attr(feature = "python", pyclass(module = "daft.daft"))]
pub enum ImageFormat {
PNG,
JPEG,
Expand Down
2 changes: 1 addition & 1 deletion src/daft-core/src/datatypes/image_mode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use common_error::{DaftError, DaftResult};
/// | RGBA32F - 32-bit floating RGB + alpha
#[allow(clippy::upper_case_acronyms)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, Hash, FromPrimitive)]
#[cfg_attr(feature = "python", pyclass)]
#[cfg_attr(feature = "python", pyclass(module = "daft.daft"))]
pub enum ImageMode {
L = 1,
LA = 2,
Expand Down
126 changes: 53 additions & 73 deletions src/daft-core/src/python/datatype.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
use crate::{
datatypes::{DataType, Field, ImageMode, TimeUnit},
ffi,
ffi, impl_bincode_py_state_serialization,
};
use pyo3::{
class::basic::CompareOp,
exceptions::PyValueError,
prelude::*,
types::{PyBytes, PyDict, PyString, PyTuple},
types::{PyBytes, PyDict, PyString},
PyTypeInfo,
};
use serde::{Deserialize, Serialize};

#[pyclass]
#[derive(Clone)]
Expand Down Expand Up @@ -66,26 +68,14 @@ impl PyTimeUnit {
}
}

#[pyclass]
#[derive(Clone)]
#[pyclass(module = "daft.daft")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PyDataType {
pub dtype: DataType,
}

#[pymethods]
impl PyDataType {
#[new]
#[pyo3(signature = (*args))]
pub fn new(args: &PyTuple) -> PyResult<Self> {
match args.len() {
0 => Ok(DataType::new_null().into()),
_ => Err(PyValueError::new_err(format!(
"expected no arguments to make new PyDataType, got : {}",
args.len()
))),
}
}

pub fn __repr__(&self) -> PyResult<String> {
Ok(format!("{}", self.dtype))
}
Expand Down Expand Up @@ -283,49 +273,51 @@ impl PyDataType {
Ok(DataType::Python.into())
}

pub fn to_arrow(&self, cast_tensor_type_for_ray: Option<bool>) -> PyResult<PyObject> {
Python::with_gil(|py| {
let pyarrow = py.import(pyo3::intern!(py, "pyarrow"))?;
let cast_tensor_to_ray_type = cast_tensor_type_for_ray.unwrap_or(false);
match (&self.dtype, cast_tensor_to_ray_type) {
(DataType::FixedShapeTensor(dtype, shape), false) => Ok(
if py
.import(pyo3::intern!(py, "daft.utils"))?
.getattr(pyo3::intern!(py, "pyarrow_supports_fixed_shape_tensor"))?
.call0()?
.extract()?
{
pyarrow
.getattr(pyo3::intern!(py, "fixed_shape_tensor"))?
.call1((
Self {
dtype: *dtype.clone(),
}
.to_arrow(None)?,
pyo3::types::PyTuple::new(py, shape.clone()),
))?
.to_object(py)
} else {
// Fall back to default Daft super extension representation if installed pyarrow doesn't have the
// canonical tensor extension type.
ffi::to_py_schema(&self.dtype.to_arrow()?, py, pyarrow)?
},
),
(DataType::FixedShapeTensor(dtype, shape), true) => Ok(py
.import(pyo3::intern!(py, "ray.data.extensions"))?
.getattr(pyo3::intern!(py, "ArrowTensorType"))?
.call1((
pyo3::types::PyTuple::new(py, shape.clone()),
Self {
dtype: *dtype.clone(),
}
.to_arrow(None)?,
))?
.to_object(py)),
(_, _) => ffi::to_py_schema(&self.dtype.to_arrow()?, py, pyarrow)?
.getattr(py, pyo3::intern!(py, "type")),
}
})
pub fn to_arrow(
&self,
py: Python,
cast_tensor_type_for_ray: Option<bool>,
) -> PyResult<PyObject> {
let pyarrow = py.import(pyo3::intern!(py, "pyarrow"))?;
let cast_tensor_to_ray_type = cast_tensor_type_for_ray.unwrap_or(false);
match (&self.dtype, cast_tensor_to_ray_type) {
(DataType::FixedShapeTensor(dtype, shape), false) => Ok(
if py
.import(pyo3::intern!(py, "daft.utils"))?
.getattr(pyo3::intern!(py, "pyarrow_supports_fixed_shape_tensor"))?
.call0()?
.extract()?
{
pyarrow
.getattr(pyo3::intern!(py, "fixed_shape_tensor"))?
.call1((
Self {
dtype: *dtype.clone(),
}
.to_arrow(py, None)?,
pyo3::types::PyTuple::new(py, shape.clone()),
))?
.to_object(py)
} else {
// Fall back to default Daft super extension representation if installed pyarrow doesn't have the
// canonical tensor extension type.
ffi::to_py_schema(&self.dtype.to_arrow()?, py, pyarrow)?
},
),
(DataType::FixedShapeTensor(dtype, shape), true) => Ok(py
.import(pyo3::intern!(py, "ray.data.extensions"))?
.getattr(pyo3::intern!(py, "ArrowTensorType"))?
.call1((
pyo3::types::PyTuple::new(py, shape.clone()),
Self {
dtype: *dtype.clone(),
}
.to_arrow(py, None)?,
))?
.to_object(py)),
(_, _) => ffi::to_py_schema(&self.dtype.to_arrow()?, py, pyarrow)?
.getattr(py, pyo3::intern!(py, "type")),
}
}

pub fn is_image(&self) -> PyResult<bool> {
Expand Down Expand Up @@ -366,20 +358,6 @@ impl PyDataType {
Ok(DataType::from_json(serialized)?.into())
}

pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
match state.extract::<&PyBytes>(py) {
Ok(s) => {
self.dtype = bincode::deserialize(s.as_bytes()).unwrap();
Ok(())
}
Err(e) => Err(e),
}
}

pub fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
Ok(PyBytes::new(py, &bincode::serialize(&self.dtype).unwrap()).to_object(py))
}

pub fn __hash__(&self) -> u64 {
use std::collections::hash_map::DefaultHasher;
use std::hash::Hash;
Expand All @@ -390,6 +368,8 @@ impl PyDataType {
}
}

impl_bincode_py_state_serialization!(PyDataType);

impl From<DataType> for PyDataType {
fn from(value: DataType) -> Self {
PyDataType { dtype: value }
Expand Down
42 changes: 8 additions & 34 deletions src/daft-core/src/python/field.rs
Original file line number Diff line number Diff line change
@@ -1,32 +1,18 @@
use pyo3::{
exceptions::PyValueError,
prelude::*,
types::{PyBytes, PyTuple},
};
use pyo3::{prelude::*, types::PyBytes, PyTypeInfo};
use serde::{Deserialize, Serialize};

use super::datatype::PyDataType;
use crate::datatypes::{self, DataType, Field};
use crate::datatypes;
use crate::impl_bincode_py_state_serialization;

#[pyclass]
#[derive(Clone)]
#[pyclass(module = "daft.daft")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PyField {
pub field: datatypes::Field,
}

#[pymethods]
impl PyField {
#[new]
#[pyo3(signature = (*args))]
pub fn new(args: &PyTuple) -> PyResult<Self> {
match args.len() {
0 => Ok(Field::new("null", DataType::new_null()).into()),
_ => Err(PyValueError::new_err(format!(
"expected no arguments to make new PyDataType, got : {}",
args.len()
))),
}
}

pub fn name(&self) -> PyResult<String> {
Ok(self.field.name.clone())
}
Expand All @@ -38,22 +24,10 @@ impl PyField {
pub fn eq(&self, other: &PyField) -> PyResult<bool> {
Ok(self.field.eq(&other.field))
}

pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
match state.extract::<&PyBytes>(py) {
Ok(s) => {
self.field = bincode::deserialize(s.as_bytes()).unwrap();
Ok(())
}
Err(e) => Err(e),
}
}

pub fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
Ok(PyBytes::new(py, &bincode::serialize(&self.field).unwrap()).to_object(py))
}
}

impl_bincode_py_state_serialization!(PyField);

impl From<datatypes::Field> for PyField {
fn from(field: datatypes::Field) -> Self {
PyField { field }
Expand Down
Loading

0 comments on commit e01594f

Please sign in to comment.