-
Notifications
You must be signed in to change notification settings - Fork 175
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[CHORE] Centralize pyo3 pickling around `__reduce__` + bincode macro.
#1394
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,15 @@ | ||
use crate::{ | ||
datatypes::{DataType, Field, ImageMode, TimeUnit}, | ||
ffi, | ||
ffi, impl_bincode_py_state_serialization, | ||
}; | ||
use pyo3::{ | ||
class::basic::CompareOp, | ||
exceptions::PyValueError, | ||
prelude::*, | ||
types::{PyBytes, PyDict, PyString, PyTuple}, | ||
types::{PyBytes, PyDict, PyString}, | ||
PyTypeInfo, | ||
}; | ||
use serde::{Deserialize, Serialize}; | ||
|
||
#[pyclass] | ||
#[derive(Clone)] | ||
|
@@ -66,26 +68,14 @@ impl PyTimeUnit { | |
} | ||
} | ||
|
||
#[pyclass] | ||
#[derive(Clone)] | ||
#[pyclass(module = "daft.daft")] | ||
#[derive(Debug, Clone, Serialize, Deserialize)] | ||
pub struct PyDataType { | ||
pub dtype: DataType, | ||
} | ||
|
||
#[pymethods] | ||
impl PyDataType { | ||
#[new] | ||
#[pyo3(signature = (*args))] | ||
pub fn new(args: &PyTuple) -> PyResult<Self> { | ||
match args.len() { | ||
0 => Ok(DataType::new_null().into()), | ||
_ => Err(PyValueError::new_err(format!( | ||
"expected no arguments to make new PyDataType, got : {}", | ||
args.len() | ||
))), | ||
} | ||
} | ||
|
||
pub fn __repr__(&self) -> PyResult<String> { | ||
Ok(format!("{}", self.dtype)) | ||
} | ||
|
@@ -283,49 +273,51 @@ impl PyDataType { | |
Ok(DataType::Python.into()) | ||
} | ||
|
||
pub fn to_arrow(&self, cast_tensor_type_for_ray: Option<bool>) -> PyResult<PyObject> { | ||
Python::with_gil(|py| { | ||
let pyarrow = py.import(pyo3::intern!(py, "pyarrow"))?; | ||
let cast_tensor_to_ray_type = cast_tensor_type_for_ray.unwrap_or(false); | ||
match (&self.dtype, cast_tensor_to_ray_type) { | ||
(DataType::FixedShapeTensor(dtype, shape), false) => Ok( | ||
if py | ||
.import(pyo3::intern!(py, "daft.utils"))? | ||
.getattr(pyo3::intern!(py, "pyarrow_supports_fixed_shape_tensor"))? | ||
.call0()? | ||
.extract()? | ||
{ | ||
pyarrow | ||
.getattr(pyo3::intern!(py, "fixed_shape_tensor"))? | ||
.call1(( | ||
Self { | ||
dtype: *dtype.clone(), | ||
} | ||
.to_arrow(None)?, | ||
pyo3::types::PyTuple::new(py, shape.clone()), | ||
))? | ||
.to_object(py) | ||
} else { | ||
// Fall back to default Daft super extension representation if installed pyarrow doesn't have the | ||
// canonical tensor extension type. | ||
ffi::to_py_schema(&self.dtype.to_arrow()?, py, pyarrow)? | ||
}, | ||
), | ||
(DataType::FixedShapeTensor(dtype, shape), true) => Ok(py | ||
.import(pyo3::intern!(py, "ray.data.extensions"))? | ||
.getattr(pyo3::intern!(py, "ArrowTensorType"))? | ||
.call1(( | ||
pyo3::types::PyTuple::new(py, shape.clone()), | ||
Self { | ||
dtype: *dtype.clone(), | ||
} | ||
.to_arrow(None)?, | ||
))? | ||
.to_object(py)), | ||
(_, _) => ffi::to_py_schema(&self.dtype.to_arrow()?, py, pyarrow)? | ||
.getattr(py, pyo3::intern!(py, "type")), | ||
} | ||
}) | ||
pub fn to_arrow( | ||
&self, | ||
py: Python, | ||
cast_tensor_type_for_ray: Option<bool>, | ||
) -> PyResult<PyObject> { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. — Reviewer: Is this a bad diff? What's different? — There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. — Author: Ah, this is a drive-by! This was me noticing that we were redundantly capturing the GIL with `Python::with_gil` [remainder of the comment was lost in extraction] |
||
let pyarrow = py.import(pyo3::intern!(py, "pyarrow"))?; | ||
let cast_tensor_to_ray_type = cast_tensor_type_for_ray.unwrap_or(false); | ||
match (&self.dtype, cast_tensor_to_ray_type) { | ||
(DataType::FixedShapeTensor(dtype, shape), false) => Ok( | ||
if py | ||
.import(pyo3::intern!(py, "daft.utils"))? | ||
.getattr(pyo3::intern!(py, "pyarrow_supports_fixed_shape_tensor"))? | ||
.call0()? | ||
.extract()? | ||
{ | ||
pyarrow | ||
.getattr(pyo3::intern!(py, "fixed_shape_tensor"))? | ||
.call1(( | ||
Self { | ||
dtype: *dtype.clone(), | ||
} | ||
.to_arrow(py, None)?, | ||
pyo3::types::PyTuple::new(py, shape.clone()), | ||
))? | ||
.to_object(py) | ||
} else { | ||
// Fall back to default Daft super extension representation if installed pyarrow doesn't have the | ||
// canonical tensor extension type. | ||
ffi::to_py_schema(&self.dtype.to_arrow()?, py, pyarrow)? | ||
}, | ||
), | ||
(DataType::FixedShapeTensor(dtype, shape), true) => Ok(py | ||
.import(pyo3::intern!(py, "ray.data.extensions"))? | ||
.getattr(pyo3::intern!(py, "ArrowTensorType"))? | ||
.call1(( | ||
pyo3::types::PyTuple::new(py, shape.clone()), | ||
Self { | ||
dtype: *dtype.clone(), | ||
} | ||
.to_arrow(py, None)?, | ||
))? | ||
.to_object(py)), | ||
(_, _) => ffi::to_py_schema(&self.dtype.to_arrow()?, py, pyarrow)? | ||
.getattr(py, pyo3::intern!(py, "type")), | ||
} | ||
} | ||
|
||
pub fn is_image(&self) -> PyResult<bool> { | ||
|
@@ -366,20 +358,6 @@ impl PyDataType { | |
Ok(DataType::from_json(serialized)?.into()) | ||
} | ||
|
||
pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> { | ||
match state.extract::<&PyBytes>(py) { | ||
Ok(s) => { | ||
self.dtype = bincode::deserialize(s.as_bytes()).unwrap(); | ||
Ok(()) | ||
} | ||
Err(e) => Err(e), | ||
} | ||
} | ||
|
||
pub fn __getstate__(&self, py: Python) -> PyResult<PyObject> { | ||
Ok(PyBytes::new(py, &bincode::serialize(&self.dtype).unwrap()).to_object(py)) | ||
} | ||
|
||
pub fn __hash__(&self) -> u64 { | ||
use std::collections::hash_map::DefaultHasher; | ||
use std::hash::Hash; | ||
|
@@ -390,6 +368,8 @@ impl PyDataType { | |
} | ||
} | ||
|
||
impl_bincode_py_state_serialization!(PyDataType); | ||
|
||
impl From<DataType> for PyDataType { | ||
fn from(value: DataType) -> Self { | ||
PyDataType { dtype: value } | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,32 +1,18 @@ | ||
use pyo3::{ | ||
exceptions::PyValueError, | ||
prelude::*, | ||
types::{PyBytes, PyTuple}, | ||
}; | ||
use pyo3::{prelude::*, types::PyBytes, PyTypeInfo}; | ||
use serde::{Deserialize, Serialize}; | ||
|
||
use super::datatype::PyDataType; | ||
use crate::datatypes::{self, DataType, Field}; | ||
use crate::datatypes; | ||
use crate::impl_bincode_py_state_serialization; | ||
|
||
#[pyclass] | ||
#[derive(Clone)] | ||
#[pyclass(module = "daft.daft")] | ||
#[derive(Debug, Clone, Serialize, Deserialize)] | ||
pub struct PyField { | ||
pub field: datatypes::Field, | ||
} | ||
|
||
#[pymethods] | ||
impl PyField { | ||
#[new] | ||
#[pyo3(signature = (*args))] | ||
pub fn new(args: &PyTuple) -> PyResult<Self> { | ||
match args.len() { | ||
0 => Ok(Field::new("null", DataType::new_null()).into()), | ||
_ => Err(PyValueError::new_err(format!( | ||
"expected no arguments to make new PyDataType, got : {}", | ||
args.len() | ||
))), | ||
} | ||
} | ||
|
||
pub fn name(&self) -> PyResult<String> { | ||
Ok(self.field.name.clone()) | ||
} | ||
|
@@ -38,22 +24,10 @@ impl PyField { | |
pub fn eq(&self, other: &PyField) -> PyResult<bool> { | ||
Ok(self.field.eq(&other.field)) | ||
} | ||
|
||
pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> { | ||
match state.extract::<&PyBytes>(py) { | ||
Ok(s) => { | ||
self.field = bincode::deserialize(s.as_bytes()).unwrap(); | ||
Ok(()) | ||
} | ||
Err(e) => Err(e), | ||
} | ||
} | ||
|
||
pub fn __getstate__(&self, py: Python) -> PyResult<PyObject> { | ||
Ok(PyBytes::new(py, &bincode::serialize(&self.field).unwrap()).to_object(py)) | ||
} | ||
} | ||
|
||
impl_bincode_py_state_serialization!(PyField); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. — Reviewer: This might be a good macro to implement as a derive so we can [remainder of the comment was lost in extraction] |
||
|
||
impl From<datatypes::Field> for PyField { | ||
fn from(field: datatypes::Field) -> Self { | ||
PyField { field } | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we make an attribute alias, e.g. `daftpyclass` = `pyclass(module = "daft.daft")`, and use that everywhere?