Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CHORE] Centralize pyo3 pickling around __reduce__ + bincode macro. #1394

Merged
merged 1 commit into from
Sep 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -485,16 +485,14 @@ class PyDataType:
def is_equal(self, other: Any) -> builtins.bool: ...
@staticmethod
def from_json(serialized: str) -> PyDataType: ...
def __setstate__(self, state: Any) -> None: ...
def __getstate__(self) -> Any: ...
def __reduce__(self) -> tuple: ...
def __hash__(self) -> int: ...

class PyField:
def name(self) -> str: ...
def dtype(self) -> PyDataType: ...
def eq(self, other: PyField) -> bool: ...
def __setstate__(self, state: Any) -> None: ...
def __getstate__(self) -> Any: ...
def __reduce__(self) -> tuple: ...

class PySchema:
def __getitem__(self, name: str) -> PyField: ...
Expand All @@ -505,8 +503,7 @@ class PySchema:
def from_field_name_and_types(names_and_types: list[tuple[str, PyDataType]]) -> PySchema: ...
@staticmethod
def from_fields(fields: list[PyField]) -> PySchema: ...
def __setstate__(self, state: Any) -> None: ...
def __getstate__(self) -> Any: ...
def __reduce__(self) -> tuple: ...
def __repr__(self) -> str: ...
def _repr_html_(self) -> str: ...

Expand Down Expand Up @@ -548,8 +545,7 @@ class PyExpr:
def to_field(self, schema: PySchema) -> PyField: ...
def __repr__(self) -> str: ...
def __hash__(self) -> int: ...
def __setstate__(self, state: Any) -> None: ...
def __getstate__(self) -> Any: ...
def __reduce__(self) -> tuple: ...
def is_nan(self) -> PyExpr: ...
def dt_date(self) -> PyExpr: ...
def dt_day(self) -> PyExpr: ...
Expand Down
8 changes: 2 additions & 6 deletions daft/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,12 +458,8 @@ def __repr__(self) -> str:
def __eq__(self, other: object) -> builtins.bool:
return isinstance(other, DataType) and self._dtype.is_equal(other._dtype)

def __getstate__(self) -> bytes:
return self._dtype.__getstate__()

def __setstate__(self, state: bytes) -> None:
self._dtype = PyDataType.__new__(PyDataType)
self._dtype.__setstate__(state)
def __reduce__(self) -> tuple:
return DataType._from_pydatatype, (self._dtype,)

def __hash__(self) -> int:
return self._dtype.__hash__()
Expand Down
8 changes: 2 additions & 6 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,12 +382,8 @@ def _to_field(self, schema: Schema) -> Field:
def __hash__(self) -> int:
return self._expr.__hash__()

def __getstate__(self) -> bytes:
return self._expr.__getstate__()

def __setstate__(self, state: bytes) -> None:
self._expr = _PyExpr.__new__(_PyExpr)
self._expr.__setstate__(state)
def __reduce__(self) -> tuple:
return Expression._from_pyexpr, (self._expr,)

###
# Helper methods required by optimizer:
Expand Down
8 changes: 2 additions & 6 deletions daft/logical/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,8 @@ def union(self, other: Schema) -> Schema:

return Schema._from_pyschema(self._schema.union(other._schema))

def __getstate__(self) -> bytes:
return self._schema.__getstate__()

def __setstate__(self, state: bytes) -> None:
self._schema = _PySchema.__new__(_PySchema)
self._schema.__setstate__(state)
def __reduce__(self) -> tuple:
return Schema._from_pyschema, (self._schema,)

@classmethod
def from_parquet(
Expand Down
19 changes: 2 additions & 17 deletions src/daft-core/src/count_mode.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
#[cfg(feature = "python")]
use pyo3::{
exceptions::PyValueError,
prelude::*,
types::{PyBytes, PyTuple},
exceptions::PyValueError, prelude::*, types::PyBytes, PyObject, PyTypeInfo, ToPyObject,
};
use serde::{Deserialize, Serialize};
use std::fmt::{Display, Formatter, Result};
Expand All @@ -20,7 +18,7 @@ use common_error::{DaftError, DaftResult};
/// | Null - Count only null values.
#[allow(clippy::upper_case_acronyms)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)]
#[cfg_attr(feature = "python", pyclass)]
#[cfg_attr(feature = "python", pyclass(module = "daft.daft"))]
pub enum CountMode {
All = 1,
Valid = 2,
Expand All @@ -30,19 +28,6 @@ pub enum CountMode {
#[cfg(feature = "python")]
#[pymethods]
impl CountMode {
#[new]
#[pyo3(signature = (*args))]
pub fn new(args: &PyTuple) -> PyResult<Self> {
match args.len() {
// Create dummy variant, to be overridden by __setstate__.
0 => Ok(Self::All),
_ => Err(PyValueError::new_err(format!(
"expected no arguments to make new JoinType, got : {}",
args.len()
))),
}
}

/// Create a CountMode from its string representation.
///
/// Args:
Expand Down
2 changes: 1 addition & 1 deletion src/daft-core/src/datatypes/image_format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use common_error::{DaftError, DaftResult};
/// Supported image formats for Daft's I/O layer.
#[allow(clippy::upper_case_acronyms)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)]
#[cfg_attr(feature = "python", pyclass)]
#[cfg_attr(feature = "python", pyclass(module = "daft.daft"))]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could we make an attr that is daftpyclass=pyclass(module = "daft.daft") and use that everywhere?

pub enum ImageFormat {
PNG,
JPEG,
Expand Down
2 changes: 1 addition & 1 deletion src/daft-core/src/datatypes/image_mode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use common_error::{DaftError, DaftResult};
/// | RGBA32F - 32-bit floating RGB + alpha
#[allow(clippy::upper_case_acronyms)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, Hash, FromPrimitive)]
#[cfg_attr(feature = "python", pyclass)]
#[cfg_attr(feature = "python", pyclass(module = "daft.daft"))]
pub enum ImageMode {
L = 1,
LA = 2,
Expand Down
126 changes: 53 additions & 73 deletions src/daft-core/src/python/datatype.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
use crate::{
datatypes::{DataType, Field, ImageMode, TimeUnit},
ffi,
ffi, impl_bincode_py_state_serialization,
};
use pyo3::{
class::basic::CompareOp,
exceptions::PyValueError,
prelude::*,
types::{PyBytes, PyDict, PyString, PyTuple},
types::{PyBytes, PyDict, PyString},
PyTypeInfo,
};
use serde::{Deserialize, Serialize};

#[pyclass]
#[derive(Clone)]
Expand Down Expand Up @@ -66,26 +68,14 @@ impl PyTimeUnit {
}
}

#[pyclass]
#[derive(Clone)]
#[pyclass(module = "daft.daft")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PyDataType {
pub dtype: DataType,
}

#[pymethods]
impl PyDataType {
#[new]
#[pyo3(signature = (*args))]
pub fn new(args: &PyTuple) -> PyResult<Self> {
match args.len() {
0 => Ok(DataType::new_null().into()),
_ => Err(PyValueError::new_err(format!(
"expected no arguments to make new PyDataType, got : {}",
args.len()
))),
}
}

pub fn __repr__(&self) -> PyResult<String> {
Ok(format!("{}", self.dtype))
}
Expand Down Expand Up @@ -283,49 +273,51 @@ impl PyDataType {
Ok(DataType::Python.into())
}

pub fn to_arrow(&self, cast_tensor_type_for_ray: Option<bool>) -> PyResult<PyObject> {
Python::with_gil(|py| {
let pyarrow = py.import(pyo3::intern!(py, "pyarrow"))?;
let cast_tensor_to_ray_type = cast_tensor_type_for_ray.unwrap_or(false);
match (&self.dtype, cast_tensor_to_ray_type) {
(DataType::FixedShapeTensor(dtype, shape), false) => Ok(
if py
.import(pyo3::intern!(py, "daft.utils"))?
.getattr(pyo3::intern!(py, "pyarrow_supports_fixed_shape_tensor"))?
.call0()?
.extract()?
{
pyarrow
.getattr(pyo3::intern!(py, "fixed_shape_tensor"))?
.call1((
Self {
dtype: *dtype.clone(),
}
.to_arrow(None)?,
pyo3::types::PyTuple::new(py, shape.clone()),
))?
.to_object(py)
} else {
// Fall back to default Daft super extension representation if installed pyarrow doesn't have the
// canonical tensor extension type.
ffi::to_py_schema(&self.dtype.to_arrow()?, py, pyarrow)?
},
),
(DataType::FixedShapeTensor(dtype, shape), true) => Ok(py
.import(pyo3::intern!(py, "ray.data.extensions"))?
.getattr(pyo3::intern!(py, "ArrowTensorType"))?
.call1((
pyo3::types::PyTuple::new(py, shape.clone()),
Self {
dtype: *dtype.clone(),
}
.to_arrow(None)?,
))?
.to_object(py)),
(_, _) => ffi::to_py_schema(&self.dtype.to_arrow()?, py, pyarrow)?
.getattr(py, pyo3::intern!(py, "type")),
}
})
pub fn to_arrow(
&self,
py: Python,
cast_tensor_type_for_ray: Option<bool>,
) -> PyResult<PyObject> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this a bad diff? what's different

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, this is a driveby! This was me noticing that we were redundantly capturing the GIL with Python::with_gil(|py| ...) when pyo3 methods already have the GIL, and the py: Python marker token can be transparently injected into the method by specifying it as an argument.

let pyarrow = py.import(pyo3::intern!(py, "pyarrow"))?;
let cast_tensor_to_ray_type = cast_tensor_type_for_ray.unwrap_or(false);
match (&self.dtype, cast_tensor_to_ray_type) {
(DataType::FixedShapeTensor(dtype, shape), false) => Ok(
if py
.import(pyo3::intern!(py, "daft.utils"))?
.getattr(pyo3::intern!(py, "pyarrow_supports_fixed_shape_tensor"))?
.call0()?
.extract()?
{
pyarrow
.getattr(pyo3::intern!(py, "fixed_shape_tensor"))?
.call1((
Self {
dtype: *dtype.clone(),
}
.to_arrow(py, None)?,
pyo3::types::PyTuple::new(py, shape.clone()),
))?
.to_object(py)
} else {
// Fall back to default Daft super extension representation if installed pyarrow doesn't have the
// canonical tensor extension type.
ffi::to_py_schema(&self.dtype.to_arrow()?, py, pyarrow)?
},
),
(DataType::FixedShapeTensor(dtype, shape), true) => Ok(py
.import(pyo3::intern!(py, "ray.data.extensions"))?
.getattr(pyo3::intern!(py, "ArrowTensorType"))?
.call1((
pyo3::types::PyTuple::new(py, shape.clone()),
Self {
dtype: *dtype.clone(),
}
.to_arrow(py, None)?,
))?
.to_object(py)),
(_, _) => ffi::to_py_schema(&self.dtype.to_arrow()?, py, pyarrow)?
.getattr(py, pyo3::intern!(py, "type")),
}
}

pub fn is_image(&self) -> PyResult<bool> {
Expand Down Expand Up @@ -366,20 +358,6 @@ impl PyDataType {
Ok(DataType::from_json(serialized)?.into())
}

pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
match state.extract::<&PyBytes>(py) {
Ok(s) => {
self.dtype = bincode::deserialize(s.as_bytes()).unwrap();
Ok(())
}
Err(e) => Err(e),
}
}

pub fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
Ok(PyBytes::new(py, &bincode::serialize(&self.dtype).unwrap()).to_object(py))
}

pub fn __hash__(&self) -> u64 {
use std::collections::hash_map::DefaultHasher;
use std::hash::Hash;
Expand All @@ -390,6 +368,8 @@ impl PyDataType {
}
}

impl_bincode_py_state_serialization!(PyDataType);

impl From<DataType> for PyDataType {
fn from(value: DataType) -> Self {
PyDataType { dtype: value }
Expand Down
42 changes: 8 additions & 34 deletions src/daft-core/src/python/field.rs
Original file line number Diff line number Diff line change
@@ -1,32 +1,18 @@
use pyo3::{
exceptions::PyValueError,
prelude::*,
types::{PyBytes, PyTuple},
};
use pyo3::{prelude::*, types::PyBytes, PyTypeInfo};
use serde::{Deserialize, Serialize};

use super::datatype::PyDataType;
use crate::datatypes::{self, DataType, Field};
use crate::datatypes;
use crate::impl_bincode_py_state_serialization;

#[pyclass]
#[derive(Clone)]
#[pyclass(module = "daft.daft")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PyField {
pub field: datatypes::Field,
}

#[pymethods]
impl PyField {
#[new]
#[pyo3(signature = (*args))]
pub fn new(args: &PyTuple) -> PyResult<Self> {
match args.len() {
0 => Ok(Field::new("null", DataType::new_null()).into()),
_ => Err(PyValueError::new_err(format!(
"expected no arguments to make new PyDataType, got : {}",
args.len()
))),
}
}

pub fn name(&self) -> PyResult<String> {
Ok(self.field.name.clone())
}
Expand All @@ -38,22 +24,10 @@ impl PyField {
pub fn eq(&self, other: &PyField) -> PyResult<bool> {
Ok(self.field.eq(&other.field))
}

pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
match state.extract::<&PyBytes>(py) {
Ok(s) => {
self.field = bincode::deserialize(s.as_bytes()).unwrap();
Ok(())
}
Err(e) => Err(e),
}
}

pub fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
Ok(PyBytes::new(py, &bincode::serialize(&self.field).unwrap()).to_object(py))
}
}

impl_bincode_py_state_serialization!(PyField);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this might be a good macro to implement as a derive so we can #[derive(PySerDe)]


impl From<datatypes::Field> for PyField {
fn from(field: datatypes::Field) -> Self {
PyField { field }
Expand Down
Loading