diff --git a/Cargo.lock b/Cargo.lock index bd5e776001..a44fbe1c0e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1492,13 +1492,11 @@ version = "0.3.0-dev0" dependencies = [ "aws-credential-types", "chrono", - "common-error", "common-py-serde", "derive_more", "pyo3", "secrecy", "serde", - "serde_json", "typetag", ] diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index b821f58115..6da06bb7f6 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -638,13 +638,6 @@ class IOConfig: gcs: GCSConfig | None = None, http: HTTPConfig | None = None, ): ... - @staticmethod - def from_json(input: str) -> IOConfig: - """ - Recreate an IOConfig from a JSON string. - """ - ... - def replace( self, s3: S3Config | None = None, diff --git a/daft/io/config.py b/daft/io/config.py deleted file mode 100644 index 39246c1d77..0000000000 --- a/daft/io/config.py +++ /dev/null @@ -1,8 +0,0 @@ -from __future__ import annotations - -from daft.daft import IOConfig - - -def _io_config_from_json(io_config_json: str) -> IOConfig: - """Used when deserializing a serialized IOConfig object""" - return IOConfig.from_json(io_config_json) diff --git a/src/common/io-config/Cargo.toml b/src/common/io-config/Cargo.toml index ac9076ef8a..d132d74ee8 100644 --- a/src/common/io-config/Cargo.toml +++ b/src/common/io-config/Cargo.toml @@ -1,17 +1,15 @@ [dependencies] aws-credential-types = {version = "0.55.3"} chrono = {workspace = true} -common-error = {path = "../error", default-features = false} common-py-serde = {path = "../py-serde", default-features = false} derive_more = {workspace = true} pyo3 = {workspace = true, optional = true} secrecy = {version = "0.8.0", features = ["alloc"], default-features = false} serde = {workspace = true} -serde_json = {workspace = true} typetag = {workspace = true} [features] -python = ["dep:pyo3", "common-error/python", "common-py-serde/python"] +python = ["dep:pyo3", "common-py-serde/python"] [lints] workspace = true diff --git a/src/common/io-config/src/python.rs b/src/common/io-config/src/python.rs index aeae87a714..e161e052e6 100644 --- a/src/common/io-config/src/python.rs +++ b/src/common/io-config/src/python.rs @@ -8,8 +8,9 @@ use aws_credential_types::{ provider::{error::CredentialsError, ProvideCredentials}, Credentials, }; -use common_error::DaftError; -use common_py_serde::{deserialize_py_object, serialize_py_object}; +use common_py_serde::{ + deserialize_py_object, impl_bincode_py_state_serialization, serialize_py_object, +}; use pyo3::prelude::*; use serde::{Deserialize, Serialize}; @@ -131,8 +132,8 @@ pub struct GCSConfig { /// Example: /// >>> io_config = IOConfig(s3=S3Config(key_id="xxx", access_key="xxx", num_tries=10), azure=AzureConfig(anonymous=True), gcs=GCSConfig(...)) /// >>> daft.read_parquet(["s3://some-path", "az://some-other-path", "gs://path3"], io_config=io_config) -#[derive(Clone, Default)] -#[pyclass] +#[derive(Clone, Default, Serialize, Deserialize)] +#[pyclass(module = "daft.daft")] pub struct IOConfig { pub config: config::IOConfig, } @@ -234,23 +235,6 @@ impl IOConfig { }) } - #[staticmethod] - pub fn from_json(input: &str) -> PyResult { - let config: config::IOConfig = serde_json::from_str(input).map_err(DaftError::from)?; - Ok(config.into()) - } - - pub fn __reduce__(&self, py: Python) -> PyResult<(PyObject, (String,))> { - let io_config_module = py.import_bound(pyo3::intern!(py, "daft.io.config"))?; - let json_string = serde_json::to_string(&self.config).map_err(DaftError::from)?; - Ok(( - io_config_module - .getattr(pyo3::intern!(py, "_io_config_from_json"))? - .into(), - (json_string,), - )) - } - pub fn __hash__(&self) -> PyResult { use std::{collections::hash_map::DefaultHasher, hash::Hash}; @@ -260,6 +244,8 @@ impl IOConfig { } } +impl_bincode_py_state_serialization!(IOConfig); + #[pymethods] impl S3Config { #[allow(clippy::too_many_arguments)]